Compare commits
3 Commits
integratio
...
db292c2fd1
| Author | SHA1 | Date | |
|---|---|---|---|
| db292c2fd1 | |||
|
|
4e9428c376 | ||
|
|
5a084b5bf0 |
1153
Cargo.lock
generated
1153
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
43
Cargo.toml
43
Cargo.toml
@@ -4,50 +4,21 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
# TODO: make this a dev and/or debug dependency later.
|
||||
better-panic = "0.3.0"
|
||||
clap = { version = "4.5", features = [ "derive" ] }
|
||||
color-eyre = "0.6.3"
|
||||
config = { version = "0.15", features = [ "toml" ] }
|
||||
directories = "6.0"
|
||||
dotenvy_macro = "0.15"
|
||||
futures = "0.3"
|
||||
human-panic = "2.0"
|
||||
genai = "0.4.3"
|
||||
genai = "0.4.0-alpha.9"
|
||||
irc = "1.1"
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = [ "full" ] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
|
||||
[dependencies.nix]
|
||||
version = "0.30.1"
|
||||
features = [ "fs" ]
|
||||
|
||||
[dependencies.clap]
|
||||
version = "4.5"
|
||||
features = [ "derive" ]
|
||||
|
||||
[dependencies.config]
|
||||
version = "0.15"
|
||||
features = [ "toml" ]
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1.0"
|
||||
features = [ "derive" ]
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = [ "io-util", "macros", "net", "rt-multi-thread", "sync" ]
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = "0.24"
|
||||
|
||||
[dev-dependencies.cargo-husky]
|
||||
version = "1"
|
||||
features = [
|
||||
"run-cargo-check",
|
||||
"run-cargo-clippy",
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
strip = true
|
||||
opt-level = "z"
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
panic = "abort"
|
||||
strip = true
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
This is an IRC bot that. The name is based on a fictional video game villain.
|
||||
Currently it supports any LLM that uses the OpenAI style of interface. They
|
||||
can be selected via command line options, environment variables, or via a configuration
|
||||
file. There is a [configuration file](config.toml) that *should* contain all available options
|
||||
file. There is a [configureation file](config.toml) that *should* contain all available options
|
||||
currently.
|
||||
|
||||
## Some supported but ~~possibly~~ *mostly* untested LLMs:
|
||||
|
||||
| Name | Model | Base URL | Tested |
|
||||
| Name | Model | Base URL | Teested |
|
||||
|------------|-------------------|-------------------------------------------|---------|
|
||||
| OpenAI | gpt-5 | https://api.openai.com/v1 | no |
|
||||
| Deepseek | deepseek-chat | https://api.deepseek.com/v1 | yes |
|
||||
|
||||
42
robotnik.1
42
robotnik.1
@@ -1,42 +0,0 @@
|
||||
.Dd $Mdocdate$
|
||||
.Dt robotnik 1
|
||||
.Os
|
||||
.Sh NAME
|
||||
.Nm robotnik
|
||||
.Nd A simple bot that among other things uses the OpenAI API.
|
||||
.\" .Sh LIBRARY
|
||||
.\" For sections 2, 3, and 9 only.
|
||||
.\" Not used in OpenBSD.
|
||||
.Sh SYNOPSIS
|
||||
.Nm progname
|
||||
.Op Fl options
|
||||
.Ar
|
||||
.Sh DESCRIPTION
|
||||
The
|
||||
.Nm
|
||||
utility processes files ...
|
||||
.\" .Sh CONTEXT
|
||||
.\" For section 9 functions only.
|
||||
.\" .Sh IMPLEMENTATION NOTES
|
||||
.\" Not used in OpenBSD.
|
||||
.\" .Sh RETURN VALUES
|
||||
.\" For sections 2, 3, and 9 function return values only.
|
||||
.\" .Sh ENVIRONMENT
|
||||
.\" For sections 1, 6, 7, and 8 only.
|
||||
.\" .Sh FILES
|
||||
.\" .Sh EXIT STATUS
|
||||
.\" For sections 1, 6, and 8 only.
|
||||
.\" .Sh EXAMPLES
|
||||
.\" .Sh DIAGNOSTICS
|
||||
.\" For sections 1, 4, 6, 7, 8, and 9 printf/stderr messages only.
|
||||
.\" .Sh ERRORS
|
||||
.\" For sections 2, 3, 4, and 9 errno settings only.
|
||||
.\" .Sh SEE ALSO
|
||||
.\" .Xr foobar 1
|
||||
.\" .Sh STANDARDS
|
||||
.\" .Sh HISTORY
|
||||
.\" .Sh AUTHORS
|
||||
.\" .Sh CAVEATS
|
||||
.\" .Sh BUGS
|
||||
.\" .Sh SECURITY CONSIDERATIONS
|
||||
.\" Not used in OpenBSD.
|
||||
@@ -3,5 +3,5 @@ style_edition = "2024"
|
||||
comment_width = 100
|
||||
format_code_in_doc_comments = true
|
||||
imports_granularity = "Crate"
|
||||
imports_layout = "HorizontalVertical"
|
||||
imports_layout = "Vertical"
|
||||
wrap_comments = true
|
||||
|
||||
51
src/chat.rs
51
src/chat.rs
@@ -1,10 +1,24 @@
|
||||
use color_eyre::{Result, eyre::WrapErr};
|
||||
use color_eyre::{
|
||||
Result,
|
||||
eyre::{
|
||||
OptionExt,
|
||||
WrapErr,
|
||||
},
|
||||
};
|
||||
// Lots of namespace confusion potential
|
||||
use crate::qna::LLMHandle;
|
||||
use config::Config as MainConfig;
|
||||
use futures::StreamExt;
|
||||
use irc::client::prelude::{Client as IRCClient, Command, Config as IRCConfig};
|
||||
use tracing::{Level, event, instrument};
|
||||
use irc::client::prelude::{
|
||||
Client as IRCClient,
|
||||
Command,
|
||||
Config as IRCConfig,
|
||||
};
|
||||
use tracing::{
|
||||
Level,
|
||||
event,
|
||||
instrument,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Chat {
|
||||
@@ -18,7 +32,8 @@ pub async fn new(settings: &MainConfig, handle: &LLMHandle) -> Result<Chat> {
|
||||
// Going to just assign and let the irc library handle errors for now, and
|
||||
// add my own checking if necessary.
|
||||
let port: u16 = settings.get("port")?;
|
||||
let channels: Vec<String> = settings.get("channels").wrap_err("No channels provided.")?;
|
||||
let channels: Vec<String> = settings.get("channels")
|
||||
.wrap_err("No channels provided.")?;
|
||||
|
||||
event!(Level::INFO, "Channels = {:?}", channels);
|
||||
|
||||
@@ -46,22 +61,26 @@ impl Chat {
|
||||
|
||||
client.identify()?;
|
||||
|
||||
let outgoing = client
|
||||
.outgoing()
|
||||
.ok_or_eyre("Couldn't get outgoing irc sink.")?;
|
||||
let mut stream = client.stream()?;
|
||||
|
||||
while let Some(message) = stream.next().await.transpose()? {
|
||||
if let Command::PRIVMSG(channel, message) = message.command
|
||||
&& message.starts_with("!gem")
|
||||
{
|
||||
let mut msg = self.llm_handle.send_request(&message).await?;
|
||||
event!(Level::INFO, "Asked: {}", message);
|
||||
event!(Level::INFO, "Answered: {}", msg);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = outgoing.await {
|
||||
event!(Level::ERROR, "Failed to drive output: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// Make it all one line.
|
||||
msg.retain(|c| c != '\n' && c != '\r');
|
||||
msg.truncate(500);
|
||||
while let Some(message) = stream.next().await.transpose()? {
|
||||
if let Command::PRIVMSG(channel, message) = message.command {
|
||||
if message.starts_with("!gem") {
|
||||
let msg = self.llm_handle.send_request(message).await?;
|
||||
event!(Level::INFO, "Message received.");
|
||||
client
|
||||
.send_privmsg(&channel, msg)
|
||||
.wrap_err("Could not send to {channel}")?;
|
||||
.send_privmsg(channel, msg)
|
||||
.wrap_err("Couldn't send response to channel.")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
14
src/event.rs
14
src/event.rs
@@ -1,14 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct Event {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl Event {
|
||||
pub fn new(msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,293 +0,0 @@
|
||||
use std::{collections::VecDeque, path::Path, sync::Arc};
|
||||
|
||||
use color_eyre::Result;
|
||||
//use nix::{NixPath, sys::stat, unistd::mkfifo};
|
||||
use tokio::{
|
||||
// fs::File,
|
||||
io::AsyncWriteExt,
|
||||
net::{
|
||||
UnixListener,
|
||||
UnixStream,
|
||||
// unix::pipe::{self, Receiver},
|
||||
},
|
||||
sync::{RwLock, broadcast},
|
||||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::event::Event;
|
||||
|
||||
// Hard coding for now. Maybe make this a parameter to new.
|
||||
const EVENT_BUF_MAX: usize = 1000;
|
||||
|
||||
// Manager for communication with plugins.
|
||||
pub struct EventManager {
|
||||
announce: broadcast::Sender<String>, // Everything broadcasts here.
|
||||
events: Arc<RwLock<VecDeque<String>>>, // Ring buffer.
|
||||
}
|
||||
|
||||
impl EventManager {
|
||||
pub fn new() -> Result<Self> {
|
||||
let (announce, _) = broadcast::channel(100);
|
||||
|
||||
Ok(Self {
|
||||
announce,
|
||||
events: Arc::new(RwLock::new(VecDeque::<String>::new())),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn broadcast(&self, event: &Event) -> Result<()> {
|
||||
let msg = serde_json::to_string(event)? + "\n";
|
||||
|
||||
let mut events = self.events.write().await;
|
||||
|
||||
if events.len() >= EVENT_BUF_MAX {
|
||||
events.pop_front();
|
||||
}
|
||||
|
||||
events.push_back(msg.clone());
|
||||
drop(events);
|
||||
|
||||
let _ = self.announce.send(msg);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// NB: This assumes it has exclusive control of the FIFO.
|
||||
// async fn start_fifo<P>(path: &P) -> Result<()>
|
||||
// where
|
||||
// P: AsRef<Path> + NixPath + ?Sized,
|
||||
// {
|
||||
// // Just delete the old FIFO if it exists.
|
||||
// let _ = std::fs::remove_file(path);
|
||||
// mkfifo(path, stat::Mode::S_IRWXU)?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
pub async fn start_listening(self: Arc<Self>, broadcast_path: impl AsRef<Path>) {
|
||||
let listener = UnixListener::bind(broadcast_path).unwrap();
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, _addr)) => {
|
||||
info!("New broadcast subscriber");
|
||||
// Spawn a new stream for the plugin. The loop
|
||||
// runs recursively from there.
|
||||
let broadcaster = Arc::clone(&self);
|
||||
tokio::spawn(async move {
|
||||
// send events.
|
||||
let _ = broadcaster.send_events(stream).await;
|
||||
});
|
||||
}
|
||||
Err(e) => error!("Accept error: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_events(&self, stream: UnixStream) -> Result<()> {
|
||||
let mut writer = stream;
|
||||
|
||||
// Take care of history.
|
||||
let events = self.events.read().await;
|
||||
for event in events.iter() {
|
||||
writer.write_all(event.as_bytes()).await?;
|
||||
}
|
||||
drop(events);
|
||||
|
||||
// Now just broadcast the new events.
|
||||
let mut rx = self.announce.subscribe();
|
||||
while let Ok(event) = rx.recv().await {
|
||||
if writer.write_all(event.as_bytes()).await.is_err() {
|
||||
// *click*
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::rstest;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_event_manager_has_empty_buffer() {
|
||||
let manager = EventManager::new().unwrap();
|
||||
let events = manager.events.read().await;
|
||||
assert_eq!(events.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_broadcast_adds_event_to_buffer() {
|
||||
let manager = EventManager::new().unwrap();
|
||||
let event = Event::new("test message");
|
||||
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
|
||||
let events = manager.events.read().await;
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(events[0].contains("test message"));
|
||||
assert!(events[0].ends_with('\n'));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_broadcast_serializes_event_as_json() {
|
||||
let manager = EventManager::new().unwrap();
|
||||
let event = Event::new("hello world");
|
||||
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
|
||||
let events = manager.events.read().await;
|
||||
let stored = &events[0];
|
||||
|
||||
// Should be valid JSON
|
||||
let parsed: serde_json::Value = serde_json::from_str(stored.trim()).unwrap();
|
||||
assert_eq!(parsed["message"], "hello world");
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case(1)]
|
||||
#[case(10)]
|
||||
#[case(100)]
|
||||
#[case(999)]
|
||||
#[tokio::test]
|
||||
async fn test_buffer_holds_events_below_max(#[case] count: usize) {
|
||||
let manager = EventManager::new().unwrap();
|
||||
|
||||
for i in 0..count {
|
||||
let event = Event::new(format!("event {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
let events = manager.events.read().await;
|
||||
assert_eq!(events.len(), count);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_buffer_at_exactly_max_capacity() {
|
||||
let manager = EventManager::new().unwrap();
|
||||
|
||||
// Fill to exactly EVENT_BUF_MAX (1000)
|
||||
for i in 0..EVENT_BUF_MAX {
|
||||
let event = Event::new(format!("event {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
let events = manager.events.read().await;
|
||||
assert_eq!(events.len(), EVENT_BUF_MAX);
|
||||
assert!(events[0].contains("event 0"));
|
||||
assert!(events[EVENT_BUF_MAX - 1].contains("event 999"));
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case(1)]
|
||||
#[case(10)]
|
||||
#[case(100)]
|
||||
#[case(500)]
|
||||
#[tokio::test]
|
||||
async fn test_buffer_overflow_evicts_oldest_fifo(#[case] overflow: usize) {
|
||||
let manager = EventManager::new().unwrap();
|
||||
let total = EVENT_BUF_MAX + overflow;
|
||||
|
||||
// Broadcast more events than buffer can hold
|
||||
for i in 0..total {
|
||||
let event = Event::new(format!("event {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
let events = manager.events.read().await;
|
||||
|
||||
// Buffer should still be at max capacity
|
||||
assert_eq!(events.len(), EVENT_BUF_MAX);
|
||||
|
||||
// Oldest events (0 through overflow-1) should be evicted
|
||||
// Buffer should contain events [overflow..total)
|
||||
let first_event = &events[0];
|
||||
let last_event = &events[EVENT_BUF_MAX - 1];
|
||||
|
||||
assert!(first_event.contains(&format!("event {}", overflow)));
|
||||
assert!(last_event.contains(&format!("event {}", total - 1)));
|
||||
|
||||
// Verify the evicted events are NOT in the buffer
|
||||
let buffer_string = events.iter().cloned().collect::<String>();
|
||||
assert!(!buffer_string.contains(r#""message":"event 0""#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_broadcasts_maintain_order() {
|
||||
let manager = EventManager::new().unwrap();
|
||||
let messages = vec!["first", "second", "third", "fourth", "fifth"];
|
||||
|
||||
for msg in &messages {
|
||||
let event = Event::new(*msg);
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
let events = manager.events.read().await;
|
||||
assert_eq!(events.len(), messages.len());
|
||||
|
||||
for (i, expected) in messages.iter().enumerate() {
|
||||
assert!(events[i].contains(expected));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_buffer_wraparound_maintains_newest_events() {
|
||||
let manager = EventManager::new().unwrap();
|
||||
|
||||
// Fill buffer completely
|
||||
for i in 0..EVENT_BUF_MAX {
|
||||
let event = Event::new(format!("old {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
// Add 5 more events
|
||||
for i in 0..5 {
|
||||
let event = Event::new(format!("new {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
let events = manager.events.read().await;
|
||||
assert_eq!(events.len(), EVENT_BUF_MAX);
|
||||
|
||||
// First 5 old events should be gone
|
||||
let buffer_string = events.iter().cloned().collect::<String>();
|
||||
assert!(!buffer_string.contains(r#""message":"old 0""#));
|
||||
assert!(!buffer_string.contains(r#""message":"old 4""#));
|
||||
|
||||
// But old 5 should still be there (now at the front)
|
||||
assert!(events[0].contains("old 5"));
|
||||
|
||||
// New events should be at the end
|
||||
assert!(events[EVENT_BUF_MAX - 5].contains("new 0"));
|
||||
assert!(events[EVENT_BUF_MAX - 1].contains("new 4"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_broadcasts_all_stored() {
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn 10 concurrent tasks, each broadcasting 10 events
|
||||
for task_id in 0..10 {
|
||||
let manager_clone = Arc::clone(&manager);
|
||||
let handle = tokio::spawn(async move {
|
||||
for i in 0..10 {
|
||||
let event = Event::new(format!("task {} event {}", task_id, i));
|
||||
manager_clone.broadcast(&event).await.unwrap();
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all tasks to complete
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
let events = manager.events.read().await;
|
||||
assert_eq!(events.len(), 100);
|
||||
}
|
||||
}
|
||||
26
src/ipc.rs
26
src/ipc.rs
@@ -1,26 +0,0 @@
|
||||
// Provides an IPC socket to communicate with other processes.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use color_eyre::Result;
|
||||
use tokio::net::UnixListener;
|
||||
|
||||
pub struct IPC {
|
||||
listener: UnixListener,
|
||||
}
|
||||
|
||||
impl IPC {
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let listener = UnixListener::bind(path)?;
|
||||
Ok(Self { listener })
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<()> {
|
||||
loop {
|
||||
match self.listener.accept().await {
|
||||
Ok((_stream, _addr)) => {}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
85
src/lib.rs
85
src/lib.rs
@@ -1,85 +0,0 @@
|
||||
// Robotnik libraries
|
||||
|
||||
use std::{os::unix::fs, sync::Arc};
|
||||
|
||||
use color_eyre::{Result, eyre::WrapErr};
|
||||
use human_panic::setup_panic;
|
||||
use tracing::{Level, info};
|
||||
use tracing_subscriber::FmtSubscriber;
|
||||
|
||||
pub mod chat;
|
||||
pub mod event;
|
||||
pub mod event_manager;
|
||||
pub mod ipc;
|
||||
pub mod qna;
|
||||
pub mod setup;
|
||||
|
||||
pub use event::Event;
|
||||
pub use event_manager::EventManager;
|
||||
pub use qna::LLMHandle;
|
||||
|
||||
const DEFAULT_INSTRUCT: &str =
|
||||
"You are a shady, yet helpful IRC bot. You try to give responses that can
|
||||
be sent in a single IRC response according to the specification. Keep answers to
|
||||
500 characters or less.";
|
||||
|
||||
// NB: Everything should fail if logging doesn't start properly.
|
||||
async fn init_logging() {
|
||||
better_panic::install();
|
||||
setup_panic!();
|
||||
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_max_level(Level::TRACE)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).unwrap();
|
||||
}
|
||||
|
||||
pub async fn run() -> Result<()> {
|
||||
init_logging().await;
|
||||
info!("Starting up.");
|
||||
|
||||
let settings = setup::init().await.wrap_err("Failed to initialize.")?;
|
||||
let config = settings.config;
|
||||
|
||||
// NOTE: Doing chroot this way might be impractical.
|
||||
if let Ok(chroot_path) = config.get_string("chroot-dir") {
|
||||
info!("Attempting to chroot to {}", chroot_path);
|
||||
fs::chroot(&chroot_path)
|
||||
.wrap_err_with(|| format!("Failed setting chroot '{}'", chroot_path))?;
|
||||
std::env::set_current_dir("/").wrap_err("Couldn't change directory after chroot.")?;
|
||||
}
|
||||
|
||||
let handle = qna::LLMHandle::new(
|
||||
config.get_string("api-key").wrap_err("API missing.")?,
|
||||
config
|
||||
.get_string("base-url")
|
||||
.wrap_err("base-url missing.")?,
|
||||
config
|
||||
.get_string("model")
|
||||
.wrap_err("model string missing.")?,
|
||||
config
|
||||
.get_string("instruct")
|
||||
.unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
|
||||
)
|
||||
.wrap_err("Couldn't initialize LLM handle.")?;
|
||||
|
||||
let ev_manager = Arc::new(EventManager::new()?);
|
||||
let ev_manager_clone = Arc::clone(&ev_manager);
|
||||
ev_manager_clone
|
||||
.broadcast(&Event::new("Starting..."))
|
||||
.await?;
|
||||
|
||||
let mut c = chat::new(&config, &handle).await?;
|
||||
|
||||
tokio::select! {
|
||||
_ = ev_manager_clone.start_listening("/tmp/robo.sock") => {
|
||||
// Event listener ended
|
||||
}
|
||||
result = c.run() => {
|
||||
result.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
63
src/main.rs
63
src/main.rs
@@ -1,6 +1,65 @@
|
||||
use color_eyre::Result;
|
||||
use color_eyre::{
|
||||
Result,
|
||||
eyre::WrapErr,
|
||||
};
|
||||
use human_panic::setup_panic;
|
||||
use std::os::unix::fs;
|
||||
use tracing::{
|
||||
Level,
|
||||
info,
|
||||
};
|
||||
use tracing_subscriber::FmtSubscriber;
|
||||
|
||||
mod chat;
|
||||
mod commands;
|
||||
mod qna;
|
||||
mod setup;
|
||||
|
||||
const DEFAULT_INSTRUCT: &str =
|
||||
"You are a shady, yet helpful IRC bot. You try to give responses that can
|
||||
be sent in a single IRC response according to the specification.";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
robotnik::run().await
|
||||
// Some error sprucing.
|
||||
better_panic::install();
|
||||
setup_panic!();
|
||||
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_max_level(Level::TRACE)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber)
|
||||
.wrap_err("Failed to setup trace logging.")?;
|
||||
|
||||
info!("Starting");
|
||||
|
||||
let settings = setup::init().await.wrap_err("Failed to initialize.")?;
|
||||
let config = settings.config;
|
||||
|
||||
// chroot if applicable.
|
||||
if let Ok(chroot_path) = config.get_string("chroot-dir") {
|
||||
fs::chroot(&chroot_path)
|
||||
.wrap_err_with(|| format!("Failed setting chroot '{}'", chroot_path.to_string()))?;
|
||||
std::env::set_current_dir("/").wrap_err("Couldn't change directory after chroot.")?;
|
||||
}
|
||||
|
||||
let handle = qna::new(
|
||||
config.get_string("api-key").wrap_err("API missing.")?,
|
||||
config
|
||||
.get_string("base-url")
|
||||
.wrap_err("base-url missing.")?,
|
||||
config
|
||||
.get_string("model")
|
||||
.wrap_err("model string missing.")?,
|
||||
config
|
||||
.get_string("instruct")
|
||||
.unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
|
||||
)
|
||||
.wrap_err("Couldn't initialize LLM handle.")?;
|
||||
let mut c = chat::new(&config, &handle).await?;
|
||||
|
||||
c.run().await.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
14
src/qna.rs
14
src/qna.rs
@@ -3,8 +3,16 @@ use futures::StreamExt;
|
||||
use genai::{
|
||||
Client,
|
||||
ModelIden,
|
||||
chat::{ChatMessage, ChatRequest, ChatStreamEvent, StreamChunk},
|
||||
resolver::{AuthData, AuthResolver},
|
||||
chat::{
|
||||
ChatMessage,
|
||||
ChatRequest,
|
||||
ChatStreamEvent,
|
||||
StreamChunk,
|
||||
},
|
||||
resolver::{
|
||||
AuthData,
|
||||
AuthResolver,
|
||||
},
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
@@ -17,7 +25,6 @@ pub struct LLMHandle {
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl LLMHandle {
|
||||
pub fn new(
|
||||
api_key: String,
|
||||
_base_url: impl AsRef<str>,
|
||||
@@ -44,6 +51,7 @@ impl LLMHandle {
|
||||
})
|
||||
}
|
||||
|
||||
impl LLMHandle {
|
||||
pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
|
||||
let mut req = self.chat_request.clone();
|
||||
let client = self.client.clone();
|
||||
|
||||
38
src/setup.rs
38
src/setup.rs
@@ -8,65 +8,65 @@ use tracing::{info, instrument};
|
||||
// TODO: use [clap(long, short, help_heading = Some(section))]
|
||||
#[derive(Clone, Debug, Parser)]
|
||||
#[command(about, version)]
|
||||
pub struct Args {
|
||||
pub(crate) struct Args {
|
||||
#[arg(short, long)]
|
||||
/// API Key for the LLM in use.
|
||||
pub api_key: Option<String>,
|
||||
pub(crate) api_key: Option<String>,
|
||||
|
||||
#[arg(short, long, default_value = "https://api.openai.com")]
|
||||
/// Base URL for the LLM API to use.
|
||||
pub base_url: Option<String>,
|
||||
pub(crate) base_url: Option<String>,
|
||||
|
||||
/// Directory to use for chroot (recommended).
|
||||
#[arg(long)]
|
||||
pub chroot_dir: Option<String>,
|
||||
pub(crate) chroot_dir: Option<String>,
|
||||
|
||||
/// Root directory for file based command structure.
|
||||
#[arg(long)]
|
||||
pub command_dir: Option<String>,
|
||||
pub(crate) command_dir: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
/// Instructions to the model on how to behave.
|
||||
pub instruct: Option<String>,
|
||||
pub(crate) intruct: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
pub model: Option<String>,
|
||||
pub(crate) model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
#[arg(long = "channel")]
|
||||
/// List of IRC channels to join.
|
||||
pub channels: Option<Vec<String>>,
|
||||
pub(crate) channels: Option<Vec<String>>,
|
||||
|
||||
#[arg(short, long)]
|
||||
/// Custom configuration file location if need be.
|
||||
pub config_file: Option<PathBuf>,
|
||||
pub(crate) config_file: Option<PathBuf>,
|
||||
|
||||
#[arg(short, long, default_value = "irc.libera.chat")]
|
||||
/// IRC server.
|
||||
pub server: Option<String>,
|
||||
pub(crate) server: Option<String>,
|
||||
|
||||
#[arg(short, long, default_value = "6697")]
|
||||
/// Port of the IRC server.
|
||||
pub port: Option<String>,
|
||||
pub(crate) port: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
/// IRC Nickname.
|
||||
pub nickname: Option<String>,
|
||||
pub(crate) nickname: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
/// IRC Nick Password
|
||||
pub nick_password: Option<String>,
|
||||
pub(crate) nick_password: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
/// IRC Username
|
||||
pub username: Option<String>,
|
||||
pub(crate) username: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
/// Whether or not to use TLS when connecting to the IRC server.
|
||||
pub use_tls: Option<bool>,
|
||||
pub(crate) use_tls: Option<bool>,
|
||||
}
|
||||
|
||||
pub struct Setup {
|
||||
pub config: Config,
|
||||
pub(crate) struct Setup {
|
||||
pub(crate) config: Config,
|
||||
}
|
||||
|
||||
#[instrument]
|
||||
@@ -98,7 +98,7 @@ pub async fn init() -> Result<Setup> {
|
||||
.set_override_option("chroot-dir", args.chroot_dir.clone())?
|
||||
.set_override_option("command-path", args.command_dir.clone())?
|
||||
.set_override_option("model", args.model.clone())?
|
||||
.set_override_option("instruct", args.instruct.clone())?
|
||||
.set_override_option("instruct", args.model.clone())?
|
||||
.set_override_option("channels", args.channels.clone())?
|
||||
.set_override_option("server", args.server.clone())?
|
||||
.set_override_option("port", args.port.clone())? // FIXME: Make this a default here not in clap.
|
||||
|
||||
@@ -1,492 +0,0 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use robotnik::{event::Event, event_manager::EventManager};
|
||||
use rstest::rstest;
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, BufReader},
|
||||
net::UnixStream,
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
const TEST_SOCKET_BASE: &str = "/tmp/robotnik_test";
|
||||
|
||||
/// Helper to create unique socket paths for parallel tests
|
||||
fn test_socket_path(name: &str) -> String {
|
||||
format!("{}_{}_{}", TEST_SOCKET_BASE, name, std::process::id())
|
||||
}
|
||||
|
||||
/// Helper to read one JSON event from a stream
|
||||
async fn read_event(
|
||||
reader: &mut BufReader<UnixStream>,
|
||||
) -> Result<Event, Box<dyn std::error::Error>> {
|
||||
let mut line = String::new();
|
||||
reader.read_line(&mut line).await?;
|
||||
let event: Event = serde_json::from_str(&line)?;
|
||||
Ok(event)
|
||||
}
|
||||
|
||||
/// Helper to read all available events with a timeout
|
||||
async fn read_events_with_timeout(
|
||||
reader: &mut BufReader<UnixStream>,
|
||||
max_count: usize,
|
||||
timeout_ms: u64,
|
||||
) -> Vec<String> {
|
||||
let mut events = Vec::new();
|
||||
for _ in 0..max_count {
|
||||
let mut line = String::new();
|
||||
match timeout(
|
||||
Duration::from_millis(timeout_ms),
|
||||
reader.read_line(&mut line),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(0)) => break, // EOF
|
||||
Ok(Ok(_)) => events.push(line),
|
||||
Ok(Err(_)) => break, // Read error
|
||||
Err(_) => break, // Timeout
|
||||
}
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_connects_and_receives_event() {
|
||||
let socket_path = test_socket_path("basic_connect");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
// Give the listener time to start
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Broadcast an event
|
||||
let event = Event::new("test message");
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
|
||||
// Connect as a client
|
||||
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
// Read the event
|
||||
let mut line = String::new();
|
||||
reader.read_line(&mut line).await.unwrap();
|
||||
|
||||
assert!(line.contains("test message"));
|
||||
assert!(line.ends_with('\n'));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_receives_event_history() {
|
||||
let socket_path = test_socket_path("event_history");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Broadcast events BEFORE starting the listener
|
||||
for i in 0..5 {
|
||||
let event = Event::new(format!("historical event {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Connect as a client
|
||||
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
// Should receive all 5 historical events
|
||||
let events = read_events_with_timeout(&mut reader, 5, 100).await;
|
||||
|
||||
assert_eq!(events.len(), 5);
|
||||
assert!(events[0].contains("historical event 0"));
|
||||
assert!(events[4].contains("historical event 4"));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_clients_receive_same_events() {
|
||||
let socket_path = test_socket_path("multiple_clients");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Connect 3 clients
|
||||
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
|
||||
let mut reader1 = BufReader::new(stream1);
|
||||
let mut reader2 = BufReader::new(stream2);
|
||||
let mut reader3 = BufReader::new(stream3);
|
||||
|
||||
// Broadcast a new event
|
||||
let event = Event::new("broadcast to all");
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
|
||||
// All clients should receive the event
|
||||
let mut line1 = String::new();
|
||||
let mut line2 = String::new();
|
||||
let mut line3 = String::new();
|
||||
|
||||
timeout(Duration::from_millis(100), reader1.read_line(&mut line1))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
timeout(Duration::from_millis(100), reader2.read_line(&mut line2))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
timeout(Duration::from_millis(100), reader3.read_line(&mut line3))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert!(line1.contains("broadcast to all"));
|
||||
assert!(line2.contains("broadcast to all"));
|
||||
assert!(line3.contains("broadcast to all"));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_late_joiner_receives_full_history() {
|
||||
let socket_path = test_socket_path("late_joiner");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// First client connects
|
||||
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader1 = BufReader::new(stream1);
|
||||
|
||||
// Broadcast several events
|
||||
for i in 0..10 {
|
||||
let event = Event::new(format!("event {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
// Consume events from first client
|
||||
let _ = read_events_with_timeout(&mut reader1, 10, 100).await;
|
||||
|
||||
// Late joiner connects
|
||||
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader2 = BufReader::new(stream2);
|
||||
|
||||
// Late joiner should receive all 10 events from history
|
||||
let events = read_events_with_timeout(&mut reader2, 10, 100).await;
|
||||
|
||||
assert_eq!(events.len(), 10);
|
||||
assert!(events[0].contains("event 0"));
|
||||
assert!(events[9].contains("event 9"));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_receives_events_in_order() {
|
||||
let socket_path = test_socket_path("event_order");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Connect client
|
||||
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
// Broadcast events rapidly
|
||||
let count = 50;
|
||||
for i in 0..count {
|
||||
let event = Event::new(format!("sequence {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
// Read all events
|
||||
let events = read_events_with_timeout(&mut reader, count, 500).await;
|
||||
|
||||
assert_eq!(events.len(), count);
|
||||
|
||||
// Verify order
|
||||
for (i, event) in events.iter().enumerate() {
|
||||
assert!(
|
||||
event.contains(&format!("sequence {}", i)),
|
||||
"Event {} out of order: {}",
|
||||
i,
|
||||
event
|
||||
);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_broadcasts_during_client_connections() {
|
||||
let socket_path = test_socket_path("concurrent_ops");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Connect client 1 BEFORE any broadcasts
|
||||
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader1 = BufReader::new(stream1);
|
||||
|
||||
// Spawn a task that continuously broadcasts
|
||||
let broadcast_manager = Arc::clone(&manager);
|
||||
let broadcast_handle = tokio::spawn(async move {
|
||||
for i in 0..100 {
|
||||
let event = Event::new(format!("concurrent event {}", i));
|
||||
broadcast_manager.broadcast(&event).await.unwrap();
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
});
|
||||
|
||||
// While broadcasting, connect more clients at different times
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader2 = BufReader::new(stream2);
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader3 = BufReader::new(stream3);
|
||||
|
||||
// Wait for broadcasts to complete
|
||||
broadcast_handle.await.unwrap();
|
||||
|
||||
// All clients should have received events
|
||||
let events1 = read_events_with_timeout(&mut reader1, 100, 200).await;
|
||||
let events2 = read_events_with_timeout(&mut reader2, 100, 200).await;
|
||||
let events3 = read_events_with_timeout(&mut reader3, 100, 200).await;
|
||||
|
||||
// Client 1 connected first (before any broadcasts), should get all 100
|
||||
assert_eq!(events1.len(), 100);
|
||||
|
||||
// Client 2 connected after ~20 events were broadcast
|
||||
// Gets ~20 from history + ~80 live = 100
|
||||
assert_eq!(events2.len(), 100);
|
||||
|
||||
// Client 3 connected after ~50 events were broadcast
|
||||
// Gets ~50 from history + ~50 live = 100
|
||||
assert_eq!(events3.len(), 100);
|
||||
|
||||
// Verify they all received events in order
|
||||
assert!(events1[0].contains("concurrent event 0"));
|
||||
assert!(events2[0].contains("concurrent event 0"));
|
||||
assert!(events3[0].contains("concurrent event 0"));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_buffer_overflow_affects_new_clients() {
|
||||
let socket_path = test_socket_path("buffer_overflow");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Broadcast more than buffer max (1000)
|
||||
for i in 0..1100 {
|
||||
let event = Event::new(format!("overflow event {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// New client connects
|
||||
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
// Should receive exactly 1000 events (buffer max)
|
||||
let events = read_events_with_timeout(&mut reader, 1100, 500).await;
|
||||
|
||||
assert_eq!(events.len(), 1000);
|
||||
|
||||
// First event should be 100 (oldest 100 were evicted)
|
||||
assert!(events[0].contains("overflow event 100"));
|
||||
|
||||
// Last event should be 1099
|
||||
assert!(events[999].contains("overflow event 1099"));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case(10, 1)]
|
||||
#[case(50, 5)]
|
||||
#[tokio::test]
|
||||
async fn test_client_count_scaling(#[case] num_clients: usize, #[case] events_per_client: usize) {
|
||||
let socket_path = test_socket_path(&format!("scaling_{}_{}", num_clients, events_per_client));
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Connect many clients
|
||||
let mut readers = Vec::new();
|
||||
for _ in 0..num_clients {
|
||||
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||
readers.push(BufReader::new(stream));
|
||||
}
|
||||
|
||||
// Broadcast events
|
||||
for i in 0..events_per_client {
|
||||
let event = Event::new(format!("scale event {}", i));
|
||||
manager.broadcast(&event).await.unwrap();
|
||||
}
|
||||
|
||||
// Verify all clients received all events
|
||||
for reader in &mut readers {
|
||||
let events = read_events_with_timeout(reader, events_per_client, 200).await;
|
||||
assert_eq!(events.len(), events_per_client);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_disconnect_doesnt_affect_others() {
|
||||
let socket_path = test_socket_path("disconnect");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Connect 3 clients
|
||||
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
|
||||
|
||||
let mut reader1 = BufReader::new(stream1);
|
||||
let mut reader2 = BufReader::new(stream2);
|
||||
let mut reader3 = BufReader::new(stream3);
|
||||
|
||||
// Broadcast initial event
|
||||
manager
|
||||
.broadcast(&Event::new("before disconnect"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// All receive it
|
||||
let _ = read_events_with_timeout(&mut reader1, 1, 100).await;
|
||||
let _ = read_events_with_timeout(&mut reader2, 1, 100).await;
|
||||
let _ = read_events_with_timeout(&mut reader3, 1, 100).await;
|
||||
|
||||
// Drop client 2 (simulates disconnect)
|
||||
drop(reader2);
|
||||
|
||||
// Broadcast another event
|
||||
manager
|
||||
.broadcast(&Event::new("after disconnect"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Clients 1 and 3 should still receive it
|
||||
let events1 = read_events_with_timeout(&mut reader1, 1, 100).await;
|
||||
let events3 = read_events_with_timeout(&mut reader3, 1, 100).await;
|
||||
|
||||
assert_eq!(events1.len(), 1);
|
||||
assert_eq!(events3.len(), 1);
|
||||
assert!(events1[0].contains("after disconnect"));
|
||||
assert!(events3[0].contains("after disconnect"));
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_deserialization_of_received_events() {
|
||||
let socket_path = test_socket_path("json_deser");
|
||||
let manager = Arc::new(EventManager::new().unwrap());
|
||||
|
||||
// Start the listener
|
||||
let listener_manager = Arc::clone(&manager);
|
||||
let socket_path_clone = socket_path.clone();
|
||||
tokio::spawn(async move {
|
||||
listener_manager.start_listening(socket_path_clone).await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Broadcast an event with special characters
|
||||
let test_message = "special chars: @#$% newline\\n tab\\t quotes \"test\"";
|
||||
manager.broadcast(&Event::new(test_message)).await.unwrap();
|
||||
|
||||
// Connect and deserialize
|
||||
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let mut line = String::new();
|
||||
reader.read_line(&mut line).await.unwrap();
|
||||
|
||||
// Should be valid JSON
|
||||
let parsed: serde_json::Value = serde_json::from_str(&line.trim()).unwrap();
|
||||
|
||||
assert_eq!(parsed["message"], test_message);
|
||||
|
||||
// Cleanup
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
}
|
||||
Reference in New Issue
Block a user