Compare commits
1 Commits
main
...
585afa5f6f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
585afa5f6f
|
1126
Cargo.lock
generated
1126
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@ color-eyre = "0.6.3"
|
||||
directories = "6.0"
|
||||
futures = "0.3"
|
||||
human-panic = "2.0"
|
||||
genai = "0.5"
|
||||
genai = "0.4.3"
|
||||
irc = "1.1"
|
||||
serde_json = "1.0"
|
||||
tracing = "0.1"
|
||||
@@ -18,7 +18,7 @@ tracing-subscriber = "0.3"
|
||||
|
||||
[dependencies.nix]
|
||||
version = "0.30.1"
|
||||
features = ["fs", "resource"]
|
||||
features = ["fs"]
|
||||
|
||||
[dependencies.clap]
|
||||
version = "4.5"
|
||||
@@ -45,8 +45,7 @@ features = [
|
||||
]
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = "~0.26"
|
||||
serial_test = "3.3"
|
||||
rstest = "0.24"
|
||||
tempfile = "3.13"
|
||||
|
||||
[dev-dependencies.cargo-husky]
|
||||
|
||||
@@ -2,8 +2,6 @@ edition = "2024"
|
||||
style_edition = "2024"
|
||||
comment_width = 100
|
||||
format_code_in_doc_comments = true
|
||||
format_macro_bodies = true
|
||||
format_macro_matchers = true
|
||||
imports_granularity = "Crate"
|
||||
imports_layout = "HorizontalVertical"
|
||||
wrap_comments = true
|
||||
|
||||
155
src/chat.rs
155
src/chat.rs
@@ -1,8 +1,3 @@
|
||||
//! Handles interaction with IRC.
|
||||
//!
|
||||
//! Each instance of [`Chat`] handles a single connection to an IRC
|
||||
//! server.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use color_eyre::{Result, eyre::WrapErr};
|
||||
@@ -12,66 +7,50 @@ use irc::client::prelude::{Client, Command, Config as IRCConfig, Message};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{Level, event, instrument};
|
||||
|
||||
use crate::{CommandDir, Event, EventManager, LLMHandle, plugin};
|
||||
use crate::{Event, EventManager, LLMHandle, plugin};
|
||||
|
||||
/// Chat struct that is used to interact with IRC chat.
|
||||
#[derive(Debug)]
|
||||
pub struct Chat {
|
||||
/// The actual IRC [`irc::client`](client).
|
||||
client: Client,
|
||||
/// Handle to the directory that *may* contain command scripts.
|
||||
command_dir: Option<CommandDir>,
|
||||
/// Event manager for handling plugin interaction.
|
||||
event_manager: Arc<EventManager>,
|
||||
/// Handle for whichever LLM is being used.
|
||||
llm_handle: LLMHandle, // FIXME: This needs to be thread safe, and shared, etc.
|
||||
}
|
||||
|
||||
// Need: owners, channels, username, nick, server, password
|
||||
#[instrument]
|
||||
pub async fn new(
|
||||
settings: &MainConfig,
|
||||
handle: &LLMHandle,
|
||||
manager: Arc<EventManager>,
|
||||
) -> Result<Chat> {
|
||||
// Going to just assign and let the irc library handle errors for now, and
|
||||
// add my own checking if necessary.
|
||||
let port: u16 = settings.get("port")?;
|
||||
let channels: Vec<String> = settings.get("channels").wrap_err("No channels provided.")?;
|
||||
|
||||
event!(Level::INFO, "Channels = {:?}", channels);
|
||||
|
||||
let config = IRCConfig {
|
||||
server: settings.get_string("server").ok(),
|
||||
nickname: settings.get_string("nickname").ok(),
|
||||
port: Some(port),
|
||||
username: settings.get_string("username").ok(),
|
||||
use_tls: settings.get_bool("use_tls").ok(),
|
||||
channels,
|
||||
..IRCConfig::default()
|
||||
};
|
||||
|
||||
event!(Level::INFO, "IRC connection starting...");
|
||||
|
||||
Ok(Chat {
|
||||
client: Client::from_config(config).await?,
|
||||
llm_handle: handle.clone(),
|
||||
event_manager: manager,
|
||||
})
|
||||
}
|
||||
|
||||
impl Chat {
|
||||
// Need: owners, channels, username, nick, server, password rather than reading
|
||||
// the config values directly.
|
||||
/// Creates a new [`Chat`].
|
||||
#[instrument]
|
||||
pub async fn new(
|
||||
settings: &MainConfig,
|
||||
handle: &LLMHandle,
|
||||
manager: Arc<EventManager>,
|
||||
) -> Result<Chat> {
|
||||
// Going to just assign and let the irc library handle errors for now, and
|
||||
// add my own checking if necessary.
|
||||
let port: u16 = settings.get("port")?;
|
||||
let channels: Vec<String> = settings.get("channels").wrap_err("No channels provided.")?;
|
||||
|
||||
event!(Level::INFO, "Channels = {:?}", channels);
|
||||
|
||||
let config = IRCConfig {
|
||||
server: settings.get_string("server").ok(),
|
||||
nickname: settings.get_string("nickname").ok(),
|
||||
port: Some(port),
|
||||
username: settings.get_string("username").ok(),
|
||||
use_tls: settings.get_bool("use_tls").ok(),
|
||||
channels,
|
||||
..IRCConfig::default()
|
||||
};
|
||||
|
||||
let commands_dir = if let Ok(path) = settings.get_string("command-path") {
|
||||
Some(CommandDir::new(path))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
event!(Level::INFO, "IRC connection starting...");
|
||||
|
||||
Ok(Chat {
|
||||
client: Client::from_config(config).await?,
|
||||
command_dir: commands_dir,
|
||||
llm_handle: handle.clone(),
|
||||
event_manager: manager,
|
||||
})
|
||||
}
|
||||
|
||||
/// Drives the event loop for the chat.
|
||||
pub async fn run(&mut self, mut command_in: mpsc::Receiver<plugin::PluginMsg>) -> Result<()> {
|
||||
pub async fn run(&mut self, mut command_in: mpsc::Receiver<plugin::Plugin>) -> Result<()> {
|
||||
self.client.identify()?;
|
||||
|
||||
let mut stream = self.client.stream()?;
|
||||
@@ -90,7 +69,7 @@ impl Chat {
|
||||
command = command_in.recv() => {
|
||||
event!(Level::INFO, "Received command {:#?}", command);
|
||||
match command {
|
||||
Some(plugin::PluginMsg::SendMessage {channel, message} ) => {
|
||||
Some(plugin::Plugin::SendMessage {channel, message} ) => {
|
||||
// Now to pass on the message.
|
||||
event!(Level::INFO, "Trying to send to channel.");
|
||||
self.client.send_privmsg(&channel, &message).wrap_err("Couldn't send to channel")?;
|
||||
@@ -120,64 +99,24 @@ impl Chat {
|
||||
|
||||
// Only handle PRIVMSG for now.
|
||||
if let Command::PRIVMSG(channel, msg) = &message.command {
|
||||
// Check it's a command.
|
||||
if let Some((cmd, args)) = command_str(msg) {
|
||||
// Command handling time.
|
||||
// Just preserve the original behavior for now.
|
||||
if msg.starts_with("!gem") {
|
||||
let mut llm_response = self.llm_handle.send_request(msg).await?;
|
||||
|
||||
match cmd {
|
||||
// Just preserve the original behavior for now.
|
||||
"!gem" => {
|
||||
let stripped_msg = msg.strip_prefix("!gem").unwrap_or(msg);
|
||||
let mut llm_response = self.llm_handle.send_request(stripped_msg).await?;
|
||||
event!(Level::INFO, "Asked: {message}");
|
||||
event!(Level::INFO, "Response: {llm_response}");
|
||||
|
||||
event!(Level::INFO, "Asked: {message}");
|
||||
event!(Level::INFO, "Response: {llm_response}");
|
||||
// Keep responses to one line for now.
|
||||
llm_response.retain(|c| c != '\n' && c != '\r');
|
||||
|
||||
// Keep responses to one line for now.
|
||||
llm_response.retain(|c| c != '\n' && c != '\r');
|
||||
// TODO: Make this configurable.
|
||||
llm_response.truncate(500);
|
||||
|
||||
// TODO: Make this configurable.
|
||||
llm_response.truncate(500);
|
||||
|
||||
event!(Level::INFO, "Sending {llm_response} to channel {channel}");
|
||||
self.client.send_privmsg(channel, llm_response)?;
|
||||
}
|
||||
|
||||
_ => {
|
||||
if let Some(cmd_dir) = &self.command_dir {
|
||||
// Strip '!'
|
||||
let cmd_name = &cmd[1..];
|
||||
match cmd_dir.run_command(cmd_name, args).await {
|
||||
Ok(res) => {
|
||||
let output = std::str::from_utf8(&res)?.to_string();
|
||||
self.client.send_privmsg(channel, output)?;
|
||||
}
|
||||
Err(e) => {
|
||||
// Log the error but don't crash, and maybe don't even tell the
|
||||
// user unless we're sure it
|
||||
// was meant to be a command?
|
||||
// For now, let's just log it.
|
||||
event!(
|
||||
Level::DEBUG,
|
||||
"Command execution failed or not found: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
event!(Level::INFO, "Sending {llm_response} to channel {channel}");
|
||||
self.client.send_privmsg(channel, llm_response)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn command_str(cmd: &str) -> Option<(&str, &str)> {
|
||||
if !cmd.starts_with('!') {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(cmd.split_once(' ').unwrap_or((cmd, "")))
|
||||
}
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
//! Commands that are associated with external processes (commands).
|
||||
//!
|
||||
//! Process based plugins are just an assortment of executable files in
|
||||
//! a provided directory. They are given arguments, and the response from
|
||||
//! them is expected on stdout.
|
||||
// Commands that are associated with external processes (commands).
|
||||
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
@@ -14,14 +10,12 @@ use color_eyre::{Result, eyre::eyre};
|
||||
use tokio::{fs::try_exists, process::Command, time::timeout};
|
||||
use tracing::{Level, event};
|
||||
|
||||
/// Handle containing information about the directory containing commands.
|
||||
#[derive(Debug)]
|
||||
pub struct CommandDir {
|
||||
command_path: PathBuf,
|
||||
}
|
||||
|
||||
impl CommandDir {
|
||||
/// Register a path containing commands.
|
||||
pub fn new(command_path: impl AsRef<Path>) -> Self {
|
||||
event!(
|
||||
Level::INFO,
|
||||
@@ -33,7 +27,6 @@ impl CommandDir {
|
||||
}
|
||||
}
|
||||
|
||||
/// Look for a command. If it exists Ok(path) is returned.
|
||||
async fn find_command(&self, name: impl AsRef<Path>) -> Result<String> {
|
||||
let path = self.command_path.join(name.as_ref());
|
||||
|
||||
@@ -50,8 +43,6 @@ impl CommandDir {
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the given [`command_name`]. It should exist in the directory provided as
|
||||
/// the command_path.
|
||||
pub async fn run_command(
|
||||
&self,
|
||||
command_name: impl AsRef<str>,
|
||||
@@ -74,7 +65,6 @@ impl CommandDir {
|
||||
}
|
||||
}
|
||||
|
||||
/// [`run_command`] but with a timeout.
|
||||
pub async fn run_command_with_timeout(
|
||||
&self,
|
||||
command_name: impl AsRef<str>,
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
//! Internal representations of incoming events.
|
||||
|
||||
use irc::proto::{Command, Message};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Represents an event. Probably from IRC.
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct Event {
|
||||
/// Who is the message from?
|
||||
from: String,
|
||||
/// What is the message?
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl Event {
|
||||
/// Creates a new message.
|
||||
pub fn new(from: impl Into<String>, msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
from: from.into(),
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! Handler for events to and from IPC, and process plugins.
|
||||
|
||||
use std::{collections::VecDeque, path::Path, sync::Arc};
|
||||
|
||||
use color_eyre::Result;
|
||||
@@ -11,14 +9,12 @@ use tokio::{
|
||||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::{event::Event, plugin::PluginMsg};
|
||||
use crate::{event::Event, plugin::Plugin};
|
||||
|
||||
// Hard coding for now. Maybe make this a parameter to new.
|
||||
const EVENT_BUF_MAX: usize = 1000;
|
||||
|
||||
/// Manager for communication with plugins.
|
||||
///
|
||||
/// Keeps events in a ring buffer to track a certain amount of history.
|
||||
// Manager for communication with plugins.
|
||||
#[derive(Debug)]
|
||||
pub struct EventManager {
|
||||
announce: broadcast::Sender<String>, // Everything broadcasts here.
|
||||
@@ -26,7 +22,6 @@ pub struct EventManager {
|
||||
}
|
||||
|
||||
impl EventManager {
|
||||
/// Create a new [`EventManager``].
|
||||
pub fn new() -> Result<Self> {
|
||||
let (announce, _) = broadcast::channel(100);
|
||||
|
||||
@@ -36,7 +31,6 @@ impl EventManager {
|
||||
})
|
||||
}
|
||||
|
||||
/// Broadcast an event to every subscribed listener.
|
||||
pub async fn broadcast(&self, event: &Event) -> Result<()> {
|
||||
let msg = serde_json::to_string(event)? + "\n";
|
||||
|
||||
@@ -55,10 +49,7 @@ impl EventManager {
|
||||
}
|
||||
|
||||
// NB: This assumes it has exclusive control of the FIFO.
|
||||
/// Opens a fifo at [`path`]. This is where some plugins can send response events
|
||||
/// to. The messages MUST be formatted in JSON and match one of the possible
|
||||
/// [`PluginMsg`](plugin messages).
|
||||
pub async fn start_fifo<P>(path: &P, command_tx: mpsc::Sender<PluginMsg>) -> Result<()>
|
||||
pub async fn start_fifo<P>(path: &P, command_tx: mpsc::Sender<Plugin>) -> Result<()>
|
||||
where
|
||||
P: AsRef<Path> + NixPath + ?Sized,
|
||||
{
|
||||
@@ -74,7 +65,7 @@ impl EventManager {
|
||||
|
||||
while reader.read_line(&mut line).await? > 0 {
|
||||
// Now handle the command.
|
||||
let cmd: PluginMsg = serde_json::from_str(&line)?;
|
||||
let cmd: Plugin = serde_json::from_str(&line)?;
|
||||
info!("Command received: {:?}.", cmd);
|
||||
command_tx.send(cmd).await?;
|
||||
line.clear();
|
||||
@@ -82,8 +73,6 @@ impl EventManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a UNIX socket that will provide broadcast messages to any client that opens
|
||||
/// the socket for listening.
|
||||
pub async fn start_listening(self: Arc<Self>, broadcast_path: impl AsRef<Path>) {
|
||||
let listener = UnixListener::bind(broadcast_path).unwrap();
|
||||
|
||||
@@ -104,7 +93,6 @@ impl EventManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send any events queued up to the [`stream`].
|
||||
async fn send_events(&self, stream: UnixStream) -> Result<()> {
|
||||
let mut writer = stream;
|
||||
|
||||
@@ -328,7 +316,7 @@ mod tests {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
|
||||
// Write a command to the FIFO
|
||||
let cmd = PluginMsg::SendMessage {
|
||||
let cmd = Plugin::SendMessage {
|
||||
channel: "#test".to_string(),
|
||||
message: "hello".to_string(),
|
||||
};
|
||||
@@ -350,7 +338,7 @@ mod tests {
|
||||
.expect("channel closed");
|
||||
|
||||
match received {
|
||||
PluginMsg::SendMessage { channel, message } => {
|
||||
Plugin::SendMessage { channel, message } => {
|
||||
assert_eq!(channel, "#test");
|
||||
assert_eq!(message, "hello");
|
||||
}
|
||||
@@ -374,15 +362,15 @@ mod tests {
|
||||
|
||||
// Write multiple commands
|
||||
let commands = vec![
|
||||
PluginMsg::SendMessage {
|
||||
Plugin::SendMessage {
|
||||
channel: "#chan1".to_string(),
|
||||
message: "first".to_string(),
|
||||
},
|
||||
PluginMsg::SendMessage {
|
||||
Plugin::SendMessage {
|
||||
channel: "#chan2".to_string(),
|
||||
message: "second".to_string(),
|
||||
},
|
||||
PluginMsg::SendMessage {
|
||||
Plugin::SendMessage {
|
||||
channel: "#chan3".to_string(),
|
||||
message: "third".to_string(),
|
||||
},
|
||||
@@ -407,7 +395,7 @@ mod tests {
|
||||
.expect("channel closed");
|
||||
|
||||
match first {
|
||||
PluginMsg::SendMessage { channel, message } => {
|
||||
Plugin::SendMessage { channel, message } => {
|
||||
assert_eq!(channel, "#chan1");
|
||||
assert_eq!(message, "first");
|
||||
}
|
||||
@@ -419,7 +407,7 @@ mod tests {
|
||||
.expect("channel closed");
|
||||
|
||||
match second {
|
||||
PluginMsg::SendMessage { channel, message } => {
|
||||
Plugin::SendMessage { channel, message } => {
|
||||
assert_eq!(channel, "#chan2");
|
||||
assert_eq!(message, "second");
|
||||
}
|
||||
@@ -431,7 +419,7 @@ mod tests {
|
||||
.expect("channel closed");
|
||||
|
||||
match third {
|
||||
PluginMsg::SendMessage { channel, message } => {
|
||||
Plugin::SendMessage { channel, message } => {
|
||||
assert_eq!(channel, "#chan3");
|
||||
assert_eq!(message, "third");
|
||||
}
|
||||
@@ -461,7 +449,7 @@ mod tests {
|
||||
let tx = pipe::OpenOptions::new().open_sender(&path).unwrap();
|
||||
let mut tx = tokio::io::BufWriter::new(tx);
|
||||
|
||||
let cmd = PluginMsg::SendMessage {
|
||||
let cmd = Plugin::SendMessage {
|
||||
channel: "#first".to_string(),
|
||||
message: "batch1".to_string(),
|
||||
};
|
||||
@@ -477,7 +465,7 @@ mod tests {
|
||||
.expect("channel closed");
|
||||
|
||||
match first {
|
||||
PluginMsg::SendMessage { channel, message } => {
|
||||
Plugin::SendMessage { channel, message } => {
|
||||
assert_eq!(channel, "#first");
|
||||
assert_eq!(message, "batch1");
|
||||
}
|
||||
@@ -494,7 +482,7 @@ mod tests {
|
||||
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
|
||||
let mut tx = tokio::io::BufWriter::new(tx);
|
||||
|
||||
let cmd = PluginMsg::SendMessage {
|
||||
let cmd = Plugin::SendMessage {
|
||||
channel: "#second".to_string(),
|
||||
message: "batch2".to_string(),
|
||||
};
|
||||
@@ -509,7 +497,7 @@ mod tests {
|
||||
.expect("channel closed");
|
||||
|
||||
match second {
|
||||
PluginMsg::SendMessage { channel, message } => {
|
||||
Plugin::SendMessage { channel, message } => {
|
||||
assert_eq!(channel, "#second");
|
||||
assert_eq!(message, "batch2");
|
||||
}
|
||||
@@ -536,7 +524,7 @@ mod tests {
|
||||
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
|
||||
let mut tx = tokio::io::BufWriter::new(tx);
|
||||
|
||||
let cmd1 = PluginMsg::SendMessage {
|
||||
let cmd1 = Plugin::SendMessage {
|
||||
channel: "#test".to_string(),
|
||||
message: "first".to_string(),
|
||||
};
|
||||
@@ -549,7 +537,7 @@ mod tests {
|
||||
// Write whitespace line
|
||||
tx.write_all(b" \n").await.unwrap();
|
||||
|
||||
let cmd2 = PluginMsg::SendMessage {
|
||||
let cmd2 = Plugin::SendMessage {
|
||||
channel: "#test".to_string(),
|
||||
message: "second".to_string(),
|
||||
};
|
||||
@@ -565,7 +553,7 @@ mod tests {
|
||||
.expect("channel closed");
|
||||
|
||||
match first {
|
||||
PluginMsg::SendMessage { channel, message } => {
|
||||
Plugin::SendMessage { channel, message } => {
|
||||
assert_eq!(channel, "#test");
|
||||
assert_eq!(message, "first");
|
||||
}
|
||||
|
||||
15
src/lib.rs
15
src/lib.rs
@@ -1,5 +1,4 @@
|
||||
#![warn(missing_docs)]
|
||||
#![doc = include_str!("../README.md")]
|
||||
// Robotnik libraries
|
||||
|
||||
use std::{os::unix::fs, sync::Arc};
|
||||
|
||||
@@ -17,8 +16,6 @@ pub mod plugin;
|
||||
pub mod qna;
|
||||
pub mod setup;
|
||||
|
||||
pub use chat::Chat;
|
||||
pub use command::CommandDir;
|
||||
pub use event::Event;
|
||||
pub use event_manager::EventManager;
|
||||
pub use qna::LLMHandle;
|
||||
@@ -28,9 +25,7 @@ const DEFAULT_INSTRUCT: &str =
|
||||
be sent in a single IRC response according to the specification. Keep answers to
|
||||
500 characters or less.";
|
||||
|
||||
/// Initialize all logging facilities.
|
||||
///
|
||||
/// This should cause a panic if there's a failure.
|
||||
// NB: Everything should fail if logging doesn't start properly.
|
||||
async fn init_logging() {
|
||||
better_panic::install();
|
||||
setup_panic!();
|
||||
@@ -42,10 +37,6 @@ async fn init_logging() {
|
||||
tracing::subscriber::set_global_default(subscriber).unwrap();
|
||||
}
|
||||
|
||||
/// Sets up and runs the main event loop.
|
||||
///
|
||||
/// Should return an error if it's recoverable, but could panic if something
|
||||
/// is particularly bad.
|
||||
pub async fn run() -> Result<()> {
|
||||
init_logging().await;
|
||||
info!("Starting up.");
|
||||
@@ -78,7 +69,7 @@ pub async fn run() -> Result<()> {
|
||||
let ev_manager = Arc::new(EventManager::new()?);
|
||||
let ev_manager_clone = Arc::clone(&ev_manager);
|
||||
|
||||
let mut c = Chat::new(&config, &handle, Arc::clone(&ev_manager)).await?;
|
||||
let mut c = chat::new(&config, &handle, Arc::clone(&ev_manager)).await?;
|
||||
|
||||
let (from_plugins, to_chat) = mpsc::channel(100);
|
||||
|
||||
|
||||
@@ -1,32 +1,13 @@
|
||||
//! Plugin command definitions.
|
||||
|
||||
// Dear future me: If you forget the JSON translations in the future you'll
|
||||
// thank me for the comment overkill.
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Message types accepted from plugins.
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub enum PluginMsg {
|
||||
/// Plugin message indicating the bot should send a [`message`] to [`channel`].
|
||||
/// {
|
||||
/// "SendMessage": {
|
||||
/// "channel": "channel_name",
|
||||
/// "message": "your message here"
|
||||
/// }
|
||||
///
|
||||
/// }
|
||||
SendMessage {
|
||||
/// The IRC channel to send the [`message`] to.
|
||||
channel: String,
|
||||
/// The [`message`] to send.
|
||||
message: String,
|
||||
},
|
||||
pub enum Plugin {
|
||||
SendMessage { channel: String, message: String },
|
||||
}
|
||||
|
||||
impl Display for PluginMsg {
|
||||
impl Display for Plugin {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::SendMessage { channel, message } => {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! Handles communication with a genai compatible LLM.
|
||||
|
||||
use color_eyre::Result;
|
||||
use futures::StreamExt;
|
||||
use genai::{
|
||||
@@ -10,11 +8,8 @@ use genai::{
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
// NB: Docs are quick and dirty as this might move into a plugin.
|
||||
|
||||
// Represents an LLM completion source.
|
||||
// FIXME: Clone is probably temporary.
|
||||
/// Struct containing information about the LLM.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LLMHandle {
|
||||
chat_request: ChatRequest,
|
||||
@@ -23,7 +18,6 @@ pub struct LLMHandle {
|
||||
}
|
||||
|
||||
impl LLMHandle {
|
||||
/// Create a new handle.
|
||||
pub fn new(
|
||||
api_key: String,
|
||||
_base_url: impl AsRef<str>,
|
||||
@@ -50,7 +44,6 @@ impl LLMHandle {
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a chat message to the LLM with the response being returned as a [`String`].
|
||||
pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
|
||||
let mut req = self.chat_request.clone();
|
||||
let client = self.client.clone();
|
||||
|
||||
75
src/setup.rs
75
src/setup.rs
@@ -1,19 +1,11 @@
|
||||
//! Handles configuration for the bot.
|
||||
//!
|
||||
//! Both command line, and configuration file options are handled here.
|
||||
|
||||
use clap::Parser;
|
||||
use color_eyre::{
|
||||
Result,
|
||||
eyre::{OptionExt, WrapErr},
|
||||
};
|
||||
use color_eyre::{Result, eyre::WrapErr};
|
||||
use config::Config;
|
||||
use directories::{BaseDirs, ProjectDirs};
|
||||
use directories::ProjectDirs;
|
||||
use std::path::PathBuf;
|
||||
use tracing::{Level, event, info, instrument};
|
||||
use tracing::{info, instrument};
|
||||
|
||||
// TODO: use [clap(long, short, help_heading = Some(section))]
|
||||
/// Struct of potential arguments.
|
||||
#[derive(Clone, Debug, Parser)]
|
||||
#[command(about, version)]
|
||||
pub struct Args {
|
||||
@@ -31,14 +23,13 @@ pub struct Args {
|
||||
|
||||
/// Root directory for file based command structure.
|
||||
#[arg(long)]
|
||||
pub command_path: Option<String>,
|
||||
pub command_dir: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
/// Instructions to the model on how to behave.
|
||||
pub instruct: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
/// Name of the model to use. E.g. 'deepseek-chat'
|
||||
pub model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
@@ -69,57 +60,21 @@ pub struct Args {
|
||||
/// IRC Username
|
||||
pub username: Option<String>,
|
||||
|
||||
#[arg(long = "no-tls")]
|
||||
#[arg(long)]
|
||||
/// Whether or not to use TLS when connecting to the IRC server.
|
||||
pub use_tls: Option<bool>,
|
||||
}
|
||||
|
||||
/// Handle for interacting with the bot configuration.
|
||||
pub struct Setup {
|
||||
/// Handle for the configuration file options.
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
#[instrument]
|
||||
/// Initialize a new [`Setup`] instance.
|
||||
///
|
||||
/// This reads the settings file which becomes the bot's default configuration.
|
||||
/// These settings shall be overridden by any command line options.
|
||||
pub async fn init() -> Result<Setup> {
|
||||
// Get arguments. These overrule configuration file, and environment
|
||||
// variables if applicable.
|
||||
let args = Args::parse();
|
||||
|
||||
let settings = make_config(args)?;
|
||||
|
||||
Ok(Setup { config: settings })
|
||||
}
|
||||
|
||||
/// Resolves a path, expanding `~` to the home directory.
|
||||
///
|
||||
/// If the path does not start with `~`, it is returned as is.
|
||||
pub fn resolve_path(path_str: &str) -> Result<PathBuf> {
|
||||
event!(Level::WARN, "resolve_path called with {path_str}");
|
||||
if let Some(stripped) = path_str.strip_prefix("~") {
|
||||
let base_dirs = BaseDirs::new().ok_or_eyre("Unable to expand '~'.")?;
|
||||
event!(
|
||||
Level::DEBUG,
|
||||
"home_dir() decided on {}",
|
||||
base_dirs.home_dir().display()
|
||||
);
|
||||
let relative = stripped
|
||||
.strip_prefix(std::path::MAIN_SEPARATOR_STR)
|
||||
.unwrap_or(stripped);
|
||||
return Ok(base_dirs.home_dir().join(relative));
|
||||
}
|
||||
|
||||
Ok(PathBuf::from(path_str))
|
||||
}
|
||||
|
||||
/// Create a configuration object from arguments.
|
||||
///
|
||||
/// This is exposed for testing purposes.
|
||||
pub fn make_config(args: Args) -> Result<Config> {
|
||||
// Use default config location unless specified.
|
||||
let config_location: PathBuf = if let Some(ref path) = args.config_file {
|
||||
path.to_owned()
|
||||
@@ -133,29 +88,25 @@ pub fn make_config(args: Args) -> Result<Config> {
|
||||
|
||||
info!("Starting.");
|
||||
|
||||
Config::builder()
|
||||
.add_source(config::File::with_name(&config_location.to_string_lossy()).required(true))
|
||||
let settings = Config::builder()
|
||||
.add_source(config::File::with_name(&config_location.to_string_lossy()).required(false))
|
||||
.add_source(config::Environment::with_prefix("BOT"))
|
||||
// Doing all of these overrides provides a unified access point for options,
|
||||
// but a derive macro could do this a bit better if this becomes too large.
|
||||
.set_override_option("api-key", args.api_key.clone())?
|
||||
.set_override_option("base-url", args.base_url.clone())?
|
||||
.set_override_option("chroot-dir", args.chroot_dir.clone())?
|
||||
.set_override_option(
|
||||
"command-path",
|
||||
// A path expansion is a panic situation so just unwrap() is fine.
|
||||
args.command_path
|
||||
.map(|p| resolve_path(&p).unwrap().to_string_lossy().to_string()),
|
||||
)?
|
||||
.set_override_option("command-path", args.command_dir.clone())?
|
||||
.set_override_option("model", args.model.clone())?
|
||||
.set_override_option("nick-password", args.nick_password.clone())?
|
||||
.set_override_option("instruct", args.instruct.clone())?
|
||||
.set_override_option("channels", args.channels.clone())?
|
||||
.set_override_option("server", args.server.clone())?
|
||||
.set_override_option("port", args.port.clone())?
|
||||
.set_override_option("port", args.port.clone())? // FIXME: Make this a default here not in clap.
|
||||
.set_override_option("nickname", args.nickname.clone())?
|
||||
.set_override_option("username", args.username.clone())?
|
||||
.set_override_option("use-tls", args.use_tls)?
|
||||
.set_override_option("use_tls", args.use_tls)?
|
||||
.build()
|
||||
.wrap_err("Couldn't read configuration settings.")
|
||||
.wrap_err("Couldn't read configuration settings.")?;
|
||||
|
||||
Ok(Setup { config: settings })
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ fn create_command(dir: &Path, name: &str, script: &str) {
|
||||
fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
|
||||
}
|
||||
|
||||
/// Parse a bot message like "!weather 07008" into (command_name, argument)
|
||||
/// Parse a bot message like "!weather 73135" into (command_name, argument)
|
||||
fn parse_bot_message(message: &str) -> Option<(&str, &str)> {
|
||||
if !message.starts_with('!') {
|
||||
return None;
|
||||
@@ -41,12 +41,12 @@ echo "Weather for $1: Sunny, 72°F"
|
||||
);
|
||||
|
||||
let cmd_dir = CommandDir::new(temp.path());
|
||||
let message = "!weather 10096";
|
||||
let message = "!weather 73135";
|
||||
|
||||
// Parse the message
|
||||
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||
assert_eq!(command_name, "weather");
|
||||
assert_eq!(arg, "10096");
|
||||
assert_eq!(arg, "73135");
|
||||
|
||||
// Find and run the command
|
||||
let result = cmd_dir.run_command(command_name, arg).await;
|
||||
@@ -54,7 +54,7 @@ echo "Weather for $1: Sunny, 72°F"
|
||||
assert!(result.is_ok());
|
||||
let bytes = result.unwrap();
|
||||
let output = String::from_utf8_lossy(&bytes);
|
||||
assert!(output.contains("Weather for 10096"));
|
||||
assert!(output.contains("Weather for 73135"));
|
||||
assert!(output.contains("Sunny"));
|
||||
}
|
||||
|
||||
@@ -253,7 +253,7 @@ echo "Why did the robot go on vacation? To recharge!"
|
||||
#[tokio::test]
|
||||
async fn test_non_bot_message_ignored() {
|
||||
// Messages not starting with ! should be ignored
|
||||
let messages = ["hello world", "weather 10096", "?help", "/command", ""];
|
||||
let messages = ["hello world", "weather 73135", "?help", "/command", ""];
|
||||
|
||||
for message in messages {
|
||||
assert!(
|
||||
|
||||
@@ -1,556 +0,0 @@
|
||||
use robotnik::setup::{Args, make_config};
|
||||
use serial_test::serial;
|
||||
use std::{fs, path::PathBuf};
|
||||
use tempfile::TempDir;
|
||||
|
||||
/// Helper to create a temporary config file
|
||||
fn create_config_file(dir: &TempDir, content: &str) -> PathBuf {
|
||||
let config_path = dir.path().join("config.toml");
|
||||
fs::write(&config_path, content).unwrap();
|
||||
config_path
|
||||
}
|
||||
|
||||
/// Helper to parse config using environment and config file
|
||||
async fn parse_config_from_file(config_path: &PathBuf) -> config::Config {
|
||||
config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_setup_make_config_overrides() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
api-key = \"file-key\"
|
||||
model = \"file-model\"
|
||||
port = 6667
|
||||
";
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
// Construct Args with overrides
|
||||
let args = Args {
|
||||
api_key: Some("cli-key".to_string()),
|
||||
base_url: None, /* Should fail if required and not in file/env? No, base-url is optional
|
||||
* in args */
|
||||
chroot_dir: None,
|
||||
command_path: None,
|
||||
instruct: None,
|
||||
model: None, // Should fallback to file
|
||||
channels: None,
|
||||
config_file: Some(config_path),
|
||||
server: None, // Should use default or file? Args has default "irc.libera.chat"
|
||||
port: Some("9999".to_string()),
|
||||
nickname: None,
|
||||
nick_password: None,
|
||||
username: None,
|
||||
use_tls: None,
|
||||
};
|
||||
|
||||
let config = make_config(args).expect("Failed to make config");
|
||||
|
||||
// Check overrides
|
||||
assert_eq!(config.get_string("api-key").unwrap(), "cli-key");
|
||||
assert_eq!(config.get_string("port").unwrap(), "9999");
|
||||
assert_eq!(config.get_int("port").unwrap(), 9999);
|
||||
|
||||
// Check fallback to file
|
||||
assert_eq!(config.get_string("model").unwrap(), "file-model");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_file_loads_all_settings() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
api-key = \"test-api-key-123\"
|
||||
base-url = \"https://api.test.com\"
|
||||
chroot-dir = \"/test/chroot\"
|
||||
command-path = \"/test/commands\"
|
||||
model = \"test-model\"
|
||||
instruct = \"Test instructions\"
|
||||
server = \"test.irc.server\"
|
||||
port = 6667
|
||||
channels = [\"#test1\", \"#test2\"]
|
||||
username = \"testuser\"
|
||||
nickname = \"testnick\"
|
||||
use-tls = false
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
// Verify all settings are loaded correctly
|
||||
assert_eq!(config.get_string("api-key").unwrap(), "test-api-key-123");
|
||||
assert_eq!(
|
||||
config.get_string("base-url").unwrap(),
|
||||
"https://api.test.com"
|
||||
);
|
||||
assert_eq!(config.get_string("chroot-dir").unwrap(), "/test/chroot");
|
||||
assert_eq!(config.get_string("command-path").unwrap(), "/test/commands");
|
||||
assert_eq!(config.get_string("model").unwrap(), "test-model");
|
||||
assert_eq!(config.get_string("instruct").unwrap(), "Test instructions");
|
||||
assert_eq!(config.get_string("server").unwrap(), "test.irc.server");
|
||||
assert_eq!(config.get_int("port").unwrap(), 6667);
|
||||
|
||||
let channels: Vec<String> = config.get("channels").unwrap();
|
||||
assert_eq!(channels, vec!["#test1", "#test2"]);
|
||||
|
||||
assert_eq!(config.get_string("username").unwrap(), "testuser");
|
||||
assert_eq!(config.get_string("nickname").unwrap(), "testnick");
|
||||
assert_eq!(config.get_bool("use-tls").unwrap(), false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_file_partial_settings() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
// Only provide required settings
|
||||
let config_content = "\
|
||||
api-key = \"minimal-key\"
|
||||
base-url = \"https://minimal.api.com\"
|
||||
model = \"minimal-model\"
|
||||
server = \"minimal.server\"
|
||||
port = 6697
|
||||
channels = [\"#minimal\"]
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
// Verify required settings are loaded
|
||||
assert_eq!(config.get_string("api-key").unwrap(), "minimal-key");
|
||||
assert_eq!(
|
||||
config.get_string("base-url").unwrap(),
|
||||
"https://minimal.api.com"
|
||||
);
|
||||
assert_eq!(config.get_string("model").unwrap(), "minimal-model");
|
||||
|
||||
// Verify optional settings are not present
|
||||
assert!(config.get_string("chroot-dir").is_err());
|
||||
assert!(config.get_string("instruct").is_err());
|
||||
assert!(config.get_string("username").is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_config_with_environment_variables() {
|
||||
// NOTE: This test documents a limitation in setup.rs
|
||||
// setup.rs uses Environment::with_prefix("BOT") without a separator
|
||||
// This means BOT_API_KEY maps to "api_key", NOT "api-key"
|
||||
// Since config.toml uses kebab-case, environment variables won't override properly
|
||||
// This is a known issue in the current implementation
|
||||
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
api_key = \"file-api-key\"
|
||||
base_url = \"https://file.api.com\"
|
||||
model = \"file-model\"
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
// Set environment variables (with BOT_ prefix as setup.rs uses)
|
||||
unsafe {
|
||||
std::env::set_var("BOT_API_KEY", "env-api-key");
|
||||
std::env::set_var("BOT_MODEL", "env-model");
|
||||
}
|
||||
|
||||
let config = config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
.add_source(config::Environment::with_prefix("BOT"))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Environment variables should override file settings (when using underscore keys)
|
||||
assert_eq!(config.get_string("api_key").unwrap(), "env-api-key");
|
||||
assert_eq!(config.get_string("model").unwrap(), "env-model");
|
||||
// File setting should be used when no env var
|
||||
assert_eq!(
|
||||
config.get_string("base_url").unwrap(),
|
||||
"https://file.api.com"
|
||||
);
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
std::env::remove_var("BOT_API_KEY");
|
||||
std::env::remove_var("BOT_MODEL");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_command_line_overrides_config_file() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
api-key = \"file-api-key\"
|
||||
base-url = \"https://file.api.com\"
|
||||
model = \"file-model\"
|
||||
server = \"file.server\"
|
||||
port = 6667
|
||||
channels = [\"#file\"]
|
||||
nickname = \"filenick\"
|
||||
username = \"fileuser\"
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
// Simulate command-line arguments overriding config file
|
||||
let config = config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
.set_override_option("api-key", Some("cli-api-key".to_string()))
|
||||
.unwrap()
|
||||
.set_override_option("model", Some("cli-model".to_string()))
|
||||
.unwrap()
|
||||
.set_override_option("server", Some("cli.server".to_string()))
|
||||
.unwrap()
|
||||
.set_override_option("nickname", Some("clinick".to_string()))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Command-line values should override file settings
|
||||
assert_eq!(config.get_string("api-key").unwrap(), "cli-api-key");
|
||||
assert_eq!(config.get_string("model").unwrap(), "cli-model");
|
||||
assert_eq!(config.get_string("server").unwrap(), "cli.server");
|
||||
assert_eq!(config.get_string("nickname").unwrap(), "clinick");
|
||||
|
||||
// Non-overridden values should come from file
|
||||
assert_eq!(
|
||||
config.get_string("base-url").unwrap(),
|
||||
"https://file.api.com"
|
||||
);
|
||||
assert_eq!(config.get_string("username").unwrap(), "fileuser");
|
||||
assert_eq!(config.get_int("port").unwrap(), 6667);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_command_line_overrides_environment_and_file() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
api_key = \"file-api-key\"
|
||||
model = \"file-model\"
|
||||
base_url = \"https://file.api.com\"
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
// Set environment variable
|
||||
unsafe {
|
||||
std::env::set_var("BOT_API_KEY", "env-api-key");
|
||||
}
|
||||
|
||||
// Build config with all three sources
|
||||
let config = config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
.add_source(config::Environment::with_prefix("BOT"))
|
||||
.set_override_option("api_key", Some("cli-api-key".to_string()))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Command-line should win over both environment and file
|
||||
assert_eq!(config.get_string("api_key").unwrap(), "cli-api-key");
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
std::env::remove_var("BOT_API_KEY");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_precedence_order() {
|
||||
// Test: CLI > Environment > Config File > Defaults
|
||||
// Using underscore keys to match how setup.rs actually works
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
api_key = \"file-key\"
|
||||
base_url = \"https://file-url.com\"
|
||||
model = \"file-model\"
|
||||
server = \"file-server\"
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
// Set environment variables
|
||||
unsafe {
|
||||
std::env::set_var("BOT_BASE_URL", "https://env-url.com");
|
||||
std::env::set_var("BOT_MODEL", "env-model");
|
||||
}
|
||||
|
||||
let config = config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
.add_source(config::Environment::with_prefix("BOT"))
|
||||
.set_override_option("model", Some("cli-model".to_string()))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// CLI overrides everything
|
||||
assert_eq!(config.get_string("model").unwrap(), "cli-model");
|
||||
|
||||
// Environment overrides file
|
||||
assert_eq!(
|
||||
config.get_string("base_url").unwrap(),
|
||||
"https://env-url.com"
|
||||
);
|
||||
|
||||
// File is used when no env or CLI
|
||||
assert_eq!(config.get_string("api_key").unwrap(), "file-key");
|
||||
assert_eq!(config.get_string("server").unwrap(), "file-server");
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
std::env::remove_var("BOT_BASE_URL");
|
||||
std::env::remove_var("BOT_MODEL");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_boolean_use_tls_setting() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
|
||||
// Test with use-tls = true (kebab-case as in config.toml)
|
||||
let config_content_true = r#"
|
||||
use-tls = true
|
||||
"#;
|
||||
let config_path = create_config_file(&temp, config_content_true);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
assert_eq!(config.get_bool("use-tls").unwrap(), true);
|
||||
|
||||
// Test with use-tls = false
|
||||
let config_content_false = r#"
|
||||
use-tls = false
|
||||
"#;
|
||||
let config_path = create_config_file(&temp, config_content_false);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
assert_eq!(config.get_bool("use-tls").unwrap(), false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_use_tls_naming_inconsistency() {
|
||||
// This test documents a bug: setup.rs uses "use_tls" (underscore)
|
||||
// but config.toml uses "use-tls" (kebab-case)
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = r#"
|
||||
use-tls = true
|
||||
"#;
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
// Build config the way setup.rs does it
|
||||
let config = config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
// setup.rs line 119 uses "use_tls" (underscore) instead of "use-tls" (kebab)
|
||||
.set_override_option("use_tls", Some(false))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// This should read from the override (false), not the file (true)
|
||||
// But due to the naming mismatch, it might not work as expected
|
||||
// The config file uses "use-tls" but the override uses "use_tls"
|
||||
|
||||
// With kebab-case (matches config.toml)
|
||||
assert_eq!(config.get_bool("use-tls").unwrap(), true);
|
||||
|
||||
// With underscore (matches setup.rs override)
|
||||
assert_eq!(config.get_bool("use_tls").unwrap(), false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channels_as_array() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
channels = [\"#chan1\", \"#chan2\", \"#chan3\"]
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
let channels: Vec<String> = config.get("channels").unwrap();
|
||||
assert_eq!(channels.len(), 3);
|
||||
assert_eq!(channels[0], "#chan1");
|
||||
assert_eq!(channels[1], "#chan2");
|
||||
assert_eq!(channels[2], "#chan3");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channels_override_from_cli() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
channels = [\"#file1\", \"#file2\"]
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
let cli_channels = vec![
|
||||
"#cli1".to_string(),
|
||||
"#cli2".to_string(),
|
||||
"#cli3".to_string(),
|
||||
];
|
||||
|
||||
let config = config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
.set_override_option("channels", Some(cli_channels.clone()))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let channels: Vec<String> = config.get("channels").unwrap();
|
||||
assert_eq!(channels, cli_channels);
|
||||
assert_eq!(channels.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_port_as_integer() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = r#"
|
||||
port = 6697
|
||||
"#;
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
// Port should be readable as both integer and string
|
||||
assert_eq!(config.get_int("port").unwrap(), 6697);
|
||||
assert_eq!(config.get_string("port").unwrap(), "6697");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_port_override_from_cli_as_string() {
|
||||
// setup.rs passes port as Option<String> from clap
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = r#"
|
||||
port = 6667
|
||||
"#;
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
let config = config::Config::builder()
|
||||
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
|
||||
.set_override_option("port", Some("9999".to_string()))
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// CLI override should work
|
||||
assert_eq!(config.get_string("port").unwrap(), "9999");
|
||||
assert_eq!(config.get_int("port").unwrap(), 9999);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_missing_required_fields_fails() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
// Create config without required api-key
|
||||
let config_content = r#"
|
||||
model = "test-model"
|
||||
"#;
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
// Should fail when trying to get required field
|
||||
assert!(config.get_string("api-key").is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_optional_instruct_field() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = r#"
|
||||
instruct = "Custom bot instructions"
|
||||
"#;
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
assert_eq!(
|
||||
config.get_string("instruct").unwrap(),
|
||||
"Custom bot instructions"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_command_path_field() {
|
||||
// command-path is in config.toml but not used anywhere in the code
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = r#"
|
||||
command-path = "/custom/commands"
|
||||
"#;
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
assert_eq!(
|
||||
config.get_string("command-path").unwrap(),
|
||||
"/custom/commands"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_chroot_dir_field() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = r#"
|
||||
chroot-dir = "/var/lib/bot/root"
|
||||
"#;
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
assert_eq!(
|
||||
config.get_string("chroot-dir").unwrap(),
|
||||
"/var/lib/bot/root"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_empty_config_file() {
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
|
||||
// Should build successfully but have no values
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
assert!(config.get_string("api-key").is_err());
|
||||
assert!(config.get_string("model").is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_all_cli_override_keys_match_config_format() {
|
||||
// This test documents which override keys in setup.rs match the config.toml format
|
||||
let temp = TempDir::new().unwrap();
|
||||
let config_content = "\
|
||||
api-key = \"test\"
|
||||
base-url = \"https://test.com\"
|
||||
chroot-dir = \"/test\"
|
||||
command-path = \"/cmds\"
|
||||
model = \"test-model\"
|
||||
instruct = \"test\"
|
||||
channels = [\"#test\"]
|
||||
server = \"test.server\"
|
||||
port = 6697
|
||||
nickname = \"test\"
|
||||
username = \"test\"
|
||||
use-tls = true
|
||||
";
|
||||
|
||||
let config_path = create_config_file(&temp, config_content);
|
||||
let config = parse_config_from_file(&config_path).await;
|
||||
|
||||
// All these should work with kebab-case (as in config.toml)
|
||||
assert!(config.get_string("api-key").is_ok());
|
||||
assert!(config.get_string("base-url").is_ok());
|
||||
assert!(config.get_string("chroot-dir").is_ok());
|
||||
assert!(config.get_string("command-path").is_ok());
|
||||
assert!(config.get_string("model").is_ok());
|
||||
assert!(config.get_string("instruct").is_ok());
|
||||
let channels_result: Result<Vec<String>, _> = config.get("channels");
|
||||
assert!(channels_result.is_ok());
|
||||
assert!(config.get_string("server").is_ok());
|
||||
assert!(config.get_int("port").is_ok());
|
||||
assert!(config.get_string("nickname").is_ok());
|
||||
assert!(config.get_string("username").is_ok());
|
||||
assert!(config.get_bool("use-tls").is_ok());
|
||||
}
|
||||
Reference in New Issue
Block a user