From 484d0a0bf2e0e1c5e3d76841b93f6f65207cfd71 Mon Sep 17 00:00:00 2001 From: hheik <4469778+hheik@users.noreply.github.com> Date: Wed, 21 Feb 2024 19:26:26 +0200 Subject: [PATCH] Implemented client/server protocol --- src/client.rs | 33 +++++- src/client/app.rs | 50 +++++---- src/client/crossterm.rs | 43 ++++---- src/client/request_queue.rs | 94 +++++++++++++++++ src/lib.rs | 203 +----------------------------------- src/main.rs | 4 + src/os_unix.rs | 92 ++++++++++++++++ src/protocol.rs | 130 +++++++++++++++++++++++ src/server.rs | 88 ++++++++++++---- 9 files changed, 471 insertions(+), 266 deletions(-) create mode 100644 src/client/request_queue.rs create mode 100644 src/os_unix.rs create mode 100644 src/protocol.rs diff --git a/src/client.rs b/src/client.rs index 859a9a1..22ac1e2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,15 +1,38 @@ -use std::{error::Error, time::Duration}; +use std::{ + error::Error, + sync::{Arc, Mutex}, + time::Duration, +}; use crate::CliArgs; +use self::{app::App, request_queue::request_queue_cleaner}; + pub mod app; pub mod crossterm; +pub mod request_queue; pub mod ui; pub fn run(args: CliArgs) -> Result<(), Box> { - crossterm::run( - Duration::from_millis(args.tick_rate), - args.enhanced_graphics, - )?; + let message_queue = Arc::new(Mutex::new(vec![])); + let server_state = Arc::new(Mutex::new(None)); + let app = App { + title: "rmp - Rust Music Player".into(), + enhanced_graphics: args.enhanced_graphics, + should_quit: false, + state: server_state.clone(), + message_queue: message_queue.clone(), + }; + let thread_builder = std::thread::Builder::new().name("request_queue".into()); + thread_builder + .spawn(move || { + request_queue_cleaner( + Duration::from_millis(args.message_rate), + message_queue.clone(), + server_state.clone(), + ) + }) + .unwrap(); + crossterm::run(app, Duration::from_millis(args.tick_rate))?; Ok(()) } diff --git a/src/client/app.rs b/src/client/app.rs index 733ed68..b65e11c 100644 --- a/src/client/app.rs +++ b/src/client/app.rs @@ -1,55 +1,61 @@ -use std::time::Duration; +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; -use interprocess::local_socket::LocalSocketStream; -use rmp::protocol::{Message, MessageType}; +use rmp::{ + protocol::{Message, MessageType}, + ServerState, +}; pub struct App { - pub socket: Option, pub title: String, pub should_quit: bool, pub enhanced_graphics: bool, + pub message_queue: Arc>>, + pub state: Arc>>, } impl App { pub fn new(title: &str, enhanced_graphics: bool) -> Self { Self { - socket: None, title: title.to_string(), should_quit: false, enhanced_graphics, + message_queue: Arc::new(Mutex::new(vec![])), + state: Arc::new(Mutex::new(None)), } } - pub fn connect(&mut self) -> Result<(), ()> { - let path = rmp::os::get_socket_path().map_err(|_| ())?; - let socket = LocalSocketStream::connect(path).map_err(|_| ())?; - self.socket = Some(socket); - Ok(()) + fn push_message(&mut self, message: Message) { + self.message_queue.lock().unwrap().push(message); } pub fn connected(&self) -> bool { - self.socket.is_some() + self.state.lock().unwrap().is_some() } - pub fn toggle_shuffle(&mut self) {} + pub fn toggle_shuffle(&mut self) { + self.push_message(Message::new(MessageType::ToggleSuffle, None)); + } - pub fn toggle_next(&mut self) {} + pub fn toggle_next(&mut self) { + self.push_message(Message::new(MessageType::ToggleNext, None)); + } - pub fn toggle_repeat(&mut self) {} + pub fn toggle_repeat(&mut self) { + self.push_message(Message::new(MessageType::ToggleRepeat, None)); + } pub fn fetch_state(&mut self) { - let mut socket = self.socket.as_mut().unwrap(); - Message::new(MessageType::FetchState, None) - .send(&mut socket) - .unwrap(); + self.push_message(Message::new(MessageType::StateFetch, None)); } pub fn on_key(&mut self, key: char) { match key { - 's' => self.toggle_shuffle(), - 'n' => self.toggle_next(), - 'r' => self.toggle_repeat(), - ' ' => self.fetch_state(), + 'S' => self.toggle_shuffle(), + 'X' => self.toggle_next(), + 'R' => self.toggle_repeat(), 'q' => self.should_quit = true, _ => (), } diff --git a/src/client/crossterm.rs b/src/client/crossterm.rs index ef97c08..bccb149 100644 --- a/src/client/crossterm.rs +++ b/src/client/crossterm.rs @@ -13,7 +13,7 @@ use ratatui::prelude::*; use super::{app::App, ui}; -pub fn run(tick_rate: Duration, enhanced_graphics: bool) -> Result<(), Box> { +pub fn run(app: App, tick_rate: Duration) -> Result<(), Box> { // setup terminal enable_raw_mode()?; let mut stdout = io::stdout(); @@ -21,8 +21,6 @@ pub fn run(tick_rate: Duration, enhanced_graphics: bool) -> Result<(), Box( loop { terminal.draw(|f| ui::draw(f, &mut app))?; - if !app.connected() { - match app.connect() { - Ok(_) => (), - Err(_) => (), - } - continue; - } - let timeout = tick_rate .checked_sub(last_tick.elapsed()) .unwrap_or_else(|| Duration::from_secs(0)); if crossterm::event::poll(timeout)? { if let Event::Key(key) = event::read()? { if key.kind == KeyEventKind::Press { - match key.code { - KeyCode::Char(c) => app.on_key(c), - KeyCode::Left => app.on_left(), - KeyCode::Up => app.on_up(), - KeyCode::Right => app.on_right(), - KeyCode::Down => app.on_down(), - KeyCode::Enter => app.on_enter(), - KeyCode::Tab => app.on_tab(), - _ => {} + if app.connected() { + match key.code { + KeyCode::Char(c) => app.on_key(c), + KeyCode::Left => app.on_left(), + KeyCode::Up => app.on_up(), + KeyCode::Right => app.on_right(), + KeyCode::Down => app.on_down(), + KeyCode::Enter => app.on_enter(), + KeyCode::Tab => app.on_tab(), + _ => {} + } + } else { + // Allow quitting while in "Not connected" screen + match key.code { + KeyCode::Char(c) => { + if c == 'q' { + app.should_quit = true; + } + } + _ => (), + } } } } } + if last_tick.elapsed() >= tick_rate { app.on_tick(last_tick.elapsed()); last_tick = Instant::now(); diff --git a/src/client/request_queue.rs b/src/client/request_queue.rs new file mode 100644 index 0000000..302785e --- /dev/null +++ b/src/client/request_queue.rs @@ -0,0 +1,94 @@ +use std::{ + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; + +use interprocess::local_socket::LocalSocketStream; +use rmp::{ + protocol::{self, Message, MessageError, MessageType}, + ServerState, +}; + +pub fn request_queue_cleaner( + message_rate: Duration, + queue: Arc>>, + state: Arc>>, +) { + let mut last_tick = Instant::now(); + let mut should_connect = true; + let mut stream: Option = None; + loop { + if should_connect { + *state.lock().unwrap() = None; + stream = Some(connect()); + should_connect = false; + } + match stream.as_mut() { + Some(mut stream) => { + for request in queue.lock().unwrap().drain(..) { + if let Err(_) = protocol::send(&mut stream, &request) { + should_connect = true; + continue; + } + if let Ok(response) = protocol::receive(&mut stream) { + match route_response(&response, &mut state.lock().unwrap()) { + Err(error) => { + eprintln!("{error:?}") + } + _ => {} + } + } + } + + // HACK: keep updating state + queue + .lock() + .unwrap() + .push(Message::new(MessageType::StateFetch, None)); + + let sleep_duration = message_rate + .checked_sub(last_tick.elapsed()) + .unwrap_or_else(|| Duration::from_secs(0)); + std::thread::sleep(sleep_duration); + last_tick = Instant::now(); + } + None => should_connect = true, + } + } +} + +/// Blocks thread until connected to socket +fn connect() -> LocalSocketStream { + let path = rmp::os::get_socket_path().unwrap(); + loop { + match LocalSocketStream::connect(path.clone()) { + Ok(stream) => return stream, + Err(_) => {} + } + std::thread::sleep(Duration::from_millis(100)); + } +} + +fn route_response(response: &Message, state: &mut Option) -> Result<(), String> { + match response.message_type { + MessageType::StateResponse => { + let body = response.body.as_ref().ok_or("Missing response body")?; + let response: ServerState = + bincode::deserialize(&body).map_err(|err| err.to_string())?; + *state = Some(response); + } + MessageType::NotImplementedAck => { + eprintln!("Server doesn't implement message") + } + MessageType::ProtocolError => { + let body = response.body.as_ref().ok_or("Missing response body")?; + let response: MessageError = + bincode::deserialize(&body).map_err(|err| err.to_string())?; + eprintln!("Server claims protocol error: {response:?}"); + } + message_type => { + eprintln!("Message handling not implemented for client: {message_type:?}"); + } + } + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index ce3f9f7..a238e6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,11 @@ use std::path::PathBuf; use serde::{Deserialize, Serialize}; +#[cfg(target_family = "unix")] +#[path = "os_unix.rs"] +pub mod os; +pub mod protocol; + #[derive(Serialize, Deserialize, Debug, Default)] pub struct ServerState { pub playlist_params: PlaylistParams, @@ -52,108 +57,6 @@ impl Default for PlaylistParams { } } -/// Protocol for client-server communication -pub mod protocol { - use std::io::{Read, Write}; - - use interprocess::local_socket::LocalSocketStream; - use serde::{Deserialize, Serialize}; - - /// Prefix messages with this header - pub const HEADER_MAGIC: [u8; 4] = [0xCA, 0xFE, 0xBA, 0xBE]; - /// Maximum allowed body size - pub const MAX_BODY_LENGTH: usize = 10 * 1024 * 1024; - - #[derive(Debug)] - pub enum MessageError { - HeaderMismatch, - BodySizeLimit, - ChecksumMismatch, - ReadError, - DeserializationError, - } - - #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Deserialize, Serialize)] - pub enum MessageType { - FetchState = 0, - FetchStateAck, - } - - #[derive(Debug, Serialize, Deserialize)] - pub struct Message { - pub message_type: MessageType, - pub body: Option>, - } - - impl Message { - pub fn new(message_type: MessageType, body: Option<&[u8]>) -> Self { - Self { - message_type, - body: body.map(|b| Vec::from(b)), - } - } - - /// Message format (values are in little-endian): - /// offset | size | explanation - /// -------+------+----------- - /// 0x00 | u32 | HEADER_MAGIC - /// 0x04 | u32 | Body checksum - /// 0x08 | u32 | Body length - /// 0x12 | ? | Body - fn as_bytes(&self) -> Vec { - let magic = &HEADER_MAGIC[..]; - let body = &bincode::serialize(self).unwrap(); - let checksum = &crc32fast::hash(&body).to_le_bytes(); - let body_length = &(body.len() as u32).to_le_bytes(); - [magic, checksum, body_length, body].concat() - } - - pub fn send(&self, stream: &mut LocalSocketStream) -> Result<(), std::io::Error> { - let bytes = self.as_bytes(); - stream.write_all(&bytes)?; - Ok(()) - } - } - - pub fn parse_stream(stream: &mut LocalSocketStream) -> Result { - let mut magic_buffer = vec![0; HEADER_MAGIC.len()]; - if let Err(_) = stream.read_exact(&mut magic_buffer) { - return Err(MessageError::ReadError); - } - if magic_buffer != HEADER_MAGIC { - return Err(MessageError::HeaderMismatch); - } - - let mut checksum_buffer = [0; 4]; - if let Err(_) = stream.read_exact(&mut checksum_buffer) { - return Err(MessageError::ReadError); - } - let expected_checksum = u32::from_le_bytes(checksum_buffer); - - let mut body_length_buffer = [0; 4]; - if let Err(_) = stream.read_exact(&mut body_length_buffer) { - return Err(MessageError::ReadError); - } - let expected_body_length = u32::from_le_bytes(body_length_buffer) as usize; - - if expected_body_length > MAX_BODY_LENGTH { - return Err(MessageError::BodySizeLimit); - } - - let mut body_buffer = vec![0; expected_body_length]; - if let Err(_) = stream.read_exact(&mut body_buffer) { - return Err(MessageError::ReadError); - } - - if crc32fast::hash(&body_buffer) != expected_checksum { - return Err(MessageError::ChecksumMismatch); - } - - println!("Stream data:\n\t{magic_buffer:?}\n\t{checksum_buffer:?}\n\t{body_length_buffer:?}\n\t{body_buffer:?}"); - bincode::deserialize(&body_buffer).map_err(|_| MessageError::DeserializationError) - } -} - pub mod server { use std::{fmt::Debug, path::PathBuf}; @@ -171,99 +74,3 @@ pub mod server { } } } - -#[cfg(target_family = "unix")] -pub mod os { - use std::{ - fs, - path::{Path, PathBuf}, - process::{id, Command, Stdio}, - }; - - use super::server::ServerError; - - pub fn reserve_pid() -> Result<(), ServerError> { - let pid_path = get_pid_path()?; - is_running()?; - - fs::write(&pid_path, id().to_string()).map_err(|err| ServerError::Io(err))?; - Command::new("chmod") - .args(&["600", &pid_path.to_string_lossy()]) - .output() - .map_err(|err| ServerError::Io(err))?; - Ok(()) - } - - pub fn is_running() -> Result { - let pid_path = get_pid_path()?; - - match fs::read(&pid_path) { - Ok(old_pid) => { - let old_pid = - String::from_utf8(old_pid).map_err(|err| ServerError::from_debuggable(err))?; - let old_pid = old_pid.trim(); - Ok(Command::new("ps") - .args(&["-p", old_pid]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status() - .map_err(|err| ServerError::Io(err))? - .success()) - } - _ => Ok(false), - } - } - - pub fn run_in_background() -> Result<(), ServerError> { - let this = std::env::args().next().unwrap(); - Command::new(this) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .args(&["-s"]) - .spawn() - .map_err(|err| ServerError::Io(err))?; - Ok(()) - } - - pub fn kill() -> Result<(), ServerError> { - let pid_path = get_pid_path()?; - let socket_path = get_socket_path()?; - let pid = String::from_utf8(fs::read(&pid_path).map_err(|err| ServerError::Io(err))?) - .map_err(|err| ServerError::from_debuggable(err))?; - let pid = pid.trim(); - Command::new("kill") - .arg(pid) - .spawn() - .map_err(|err| ServerError::Io(err))?; - Command::new("rm") - .args(&[ - "-f", - &pid_path.to_string_lossy(), - &socket_path.to_string_lossy(), - ]) - .spawn() - .map_err(|err| ServerError::Io(err))?; - Ok(()) - } - - pub fn get_socket_path() -> Result { - Ok(get_runtime_dir()?.join("rmp.socket")) - } - - fn get_runtime_dir() -> Result { - let uid = String::from_utf8( - Command::new("id") - .arg("-u") - .output() - .map_err(|err| ServerError::Io(err))? - .stdout, - ) - .map_err(|err| ServerError::from_debuggable(err))?; - let dir = Path::new("/run/user").join(uid.trim().to_string()); - Ok(dir) - } - - fn get_pid_path() -> Result { - Ok(get_runtime_dir()?.join("rmp.pid")) - } -} diff --git a/src/main.rs b/src/main.rs index 4d37ce5..34cc0b2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,6 +34,10 @@ pub struct CliArgs { #[argh(option, default = "250")] tick_rate: u64, + /// interval in ms for clearing the request queue. + #[argh(option, default = "50")] + message_rate: u64, + /// whether unicode symbols are used to improve the overall look of the app #[argh(option, default = "true")] enhanced_graphics: bool, diff --git a/src/os_unix.rs b/src/os_unix.rs new file mode 100644 index 0000000..c02a001 --- /dev/null +++ b/src/os_unix.rs @@ -0,0 +1,92 @@ +use std::{ + fs, + path::{Path, PathBuf}, + process::{id, Command, Stdio}, +}; + +use super::server::ServerError; + +pub fn reserve_pid() -> Result<(), ServerError> { + let pid_path = get_pid_path()?; + is_running()?; + + fs::write(&pid_path, id().to_string()).map_err(|err| ServerError::Io(err))?; + Command::new("chmod") + .args(&["600", &pid_path.to_string_lossy()]) + .output() + .map_err(|err| ServerError::Io(err))?; + Ok(()) +} + +pub fn is_running() -> Result { + let pid_path = get_pid_path()?; + + match fs::read(&pid_path) { + Ok(old_pid) => { + let old_pid = + String::from_utf8(old_pid).map_err(|err| ServerError::from_debuggable(err))?; + let old_pid = old_pid.trim(); + Ok(Command::new("ps") + .args(&["-p", old_pid]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .map_err(|err| ServerError::Io(err))? + .success()) + } + _ => Ok(false), + } +} + +pub fn run_in_background() -> Result<(), ServerError> { + let this = std::env::args().next().unwrap(); + Command::new(this) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .args(&["-s"]) + .spawn() + .map_err(|err| ServerError::Io(err))?; + Ok(()) +} + +pub fn kill() -> Result<(), ServerError> { + let pid_path = get_pid_path()?; + let socket_path = get_socket_path()?; + let pid = String::from_utf8(fs::read(&pid_path).map_err(|err| ServerError::Io(err))?) + .map_err(|err| ServerError::from_debuggable(err))?; + let pid = pid.trim(); + Command::new("kill") + .arg(pid) + .spawn() + .map_err(|err| ServerError::Io(err))?; + Command::new("rm") + .args(&[ + "-f", + &pid_path.to_string_lossy(), + &socket_path.to_string_lossy(), + ]) + .spawn() + .map_err(|err| ServerError::Io(err))?; + Ok(()) +} + +pub fn get_socket_path() -> Result { + Ok(get_runtime_dir()?.join("rmp.socket")) +} + +fn get_runtime_dir() -> Result { + let uid = String::from_utf8( + Command::new("id") + .arg("-u") + .output() + .map_err(|err| ServerError::Io(err))? + .stdout, + ) + .map_err(|err| ServerError::from_debuggable(err))?; + let dir = Path::new("/run/user").join(uid.trim().to_string()); + Ok(dir) +} + +fn get_pid_path() -> Result { + Ok(get_runtime_dir()?.join("rmp.pid")) +} diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..09a6796 --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,130 @@ +use std::{ + fmt::Display, + io::{Read, Write}, +}; + +use interprocess::local_socket::LocalSocketStream; +use serde::{Deserialize, Serialize}; + +use crate::ServerState; + +/// Prefix messages with this header +pub const HEADER_MAGIC: [u8; 4] = [0xCA, 0xFE, 0xBA, 0xBE]; +/// Maximum allowed body size +pub const MAX_BODY_LENGTH: usize = 10 * 1024 * 1024; + +#[derive(Debug, Serialize, Deserialize)] +pub enum MessageError { + HeaderMismatch, + BodySizeLimit, + ChecksumMismatch, + ReadError, + DeserializationError, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Deserialize, Serialize)] +pub enum MessageType { + /// Generic acknowledge + Ack = 0, + /// client/server did not know how to handle the request + NotImplementedAck, + /// Request was invalid + ProtocolError, + StateFetch, + StateResponse, + ToggleSuffle, + ToggleNext, + ToggleRepeat, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Message { + pub message_type: MessageType, + pub body: Option>, +} + +impl Message { + pub fn new(message_type: MessageType, body: Option<&[u8]>) -> Self { + Self { + message_type, + body: body.map(|b| Vec::from(b)), + } + } + + /// Message format (values are in little-endian): + /// offset | size | explanation + /// -------+------+----------- + /// 0x00 | u32 | HEADER_MAGIC + /// 0x04 | u32 | Body checksum + /// 0x08 | u32 | Body length + /// 0x12 | ? | Body + fn as_bytes(&self) -> Vec { + let magic = &HEADER_MAGIC[..]; + let body = &bincode::serialize(self).unwrap(); + let checksum = &crc32fast::hash(&body).to_le_bytes(); + let body_length = &(body.len() as u32).to_le_bytes(); + [magic, checksum, body_length, body].concat() + } + + pub fn state_response(server_state: &ServerState) -> Result { + Ok(Self { + message_type: MessageType::StateResponse, + body: Some(bincode::serialize(server_state).map_err(|err| err.to_string())?), + }) + } +} + +impl Display for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{:?}\t{}", + self.message_type, + self.body + .as_ref() + .map_or("(no body)".into(), |body| format!("{} B", body.len())) + ) + } +} + +pub fn send(stream: &mut LocalSocketStream, message: &Message) -> Result<(), std::io::Error> { + stream.write_all(&message.as_bytes())?; + Ok(()) +} + +pub fn receive(stream: &mut LocalSocketStream) -> Result { + let mut magic_buffer = vec![0; HEADER_MAGIC.len()]; + if let Err(_) = stream.read_exact(&mut magic_buffer) { + return Err(MessageError::ReadError); + } + if magic_buffer != HEADER_MAGIC { + return Err(MessageError::HeaderMismatch); + } + + let mut checksum_buffer = [0; 4]; + if let Err(_) = stream.read_exact(&mut checksum_buffer) { + return Err(MessageError::ReadError); + } + let expected_checksum = u32::from_le_bytes(checksum_buffer); + + let mut body_length_buffer = [0; 4]; + if let Err(_) = stream.read_exact(&mut body_length_buffer) { + return Err(MessageError::ReadError); + } + let expected_body_length = u32::from_le_bytes(body_length_buffer) as usize; + + if expected_body_length > MAX_BODY_LENGTH { + return Err(MessageError::BodySizeLimit); + } + + let mut body_buffer = vec![0; expected_body_length]; + if let Err(_) = stream.read_exact(&mut body_buffer) { + return Err(MessageError::ReadError); + } + + if crc32fast::hash(&body_buffer) != expected_checksum { + return Err(MessageError::ChecksumMismatch); + } + + bincode::deserialize(&body_buffer).map_err(|_| MessageError::DeserializationError) +} diff --git a/src/server.rs b/src/server.rs index 1ec0ecc..8c6f38f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,22 +1,28 @@ use interprocess::local_socket::{LocalSocketListener, LocalSocketStream}; -use rmp::{os, server::ServerError, PlaylistParams, ServerState}; -use std::{fs, path::PathBuf}; + +use rmp::{ + os, + protocol::{Message, MessageType}, + server::ServerError, + PlaylistParams, ServerState, +}; +use std::{ + fs, + sync::{Arc, Mutex}, +}; use crate::CliArgs; pub mod audio_backend; +#[derive(Debug)] pub struct Server { - pub socket: LocalSocketListener, pub state: ServerState, } impl Server { - pub fn new(socket_path: PathBuf, state: ServerState) -> Result { - Ok(Self { - socket: LocalSocketListener::bind(socket_path).map_err(|err| ServerError::Io(err))?, - state, - }) + pub fn from_state(state: ServerState) -> Self { + Self { state } } } @@ -64,15 +70,19 @@ fn serve(state: ServerState) -> Result<(), ServerError> { if socket_path.exists() { fs::remove_file(&socket_path).map_err(|err| ServerError::Io(err))?; } - println!("state: {state:?}"); - let server = Server::new(socket_path, state)?; + let socket = LocalSocketListener::bind(socket_path).map_err(|err| ServerError::Io(err))?; + let server = Arc::new(Mutex::new(Server::from_state(state))); println!("Waiting for connections..."); - for message in server.socket.incoming() { + let mut session_counter = 0; + for message in socket.incoming() { match message { Ok(stream) => { - let thread_builder = std::thread::Builder::new().name("session_handler".into()); + session_counter += 1; + let thread_builder = + std::thread::Builder::new().name(format!("session_{session_counter}")); + let server = server.clone(); thread_builder - .spawn(move || session_handler(stream)) + .spawn(move || session_handler(stream, server)) .unwrap(); } Err(err) => { @@ -85,21 +95,35 @@ fn serve(state: ServerState) -> Result<(), ServerError> { Ok(()) } -fn session_handler(mut stream: LocalSocketStream) { - let thread_id = std::thread::current().id(); - println!("session created: {thread_id:?}"); +fn session_handler(mut stream: LocalSocketStream, server: Arc>) { + let thread = std::thread::current(); + let session_id = thread.name().unwrap_or(""); + println!("[{session_id}] session created"); loop { - match rmp::protocol::parse_stream(&mut stream) { - Ok(body) => { - println!("Message: {body:?}") + match rmp::protocol::receive(&mut stream) { + Ok(message) => { + println!("[{session_id}] rx {message}"); + match route_request(&message, &mut server.lock().unwrap()) { + Err(err) => { + eprintln!("[{session_id}] rx Error: {err}"); + } + Ok(response) => { + println!("[{session_id}] tx {response}"); + rmp::protocol::send(&mut stream, &response).unwrap(); + } + } } Err(error) => match error { rmp::protocol::MessageError::ReadError => { - println!("session terminated: {thread_id:?}"); + println!("[{session_id}] session terminated"); return; } error => { - eprintln!("Message error in {thread_id:?}: {error:?}") + eprintln!("[{session_id}] rx {error:?}"); + let body = bincode::serialize(&error).unwrap(); + let message = Message::new(MessageType::ProtocolError, Some(&body)); + println!("[{session_id}] tx {message}"); + rmp::protocol::send(&mut stream, &message).unwrap(); } }, } @@ -126,3 +150,25 @@ fn handle_error(err: ServerError) -> i32 { } } } + +fn route_request(request: &Message, server: &mut Server) -> Result { + match request.message_type { + MessageType::StateFetch => { + return Message::state_response(&server.state); + } + MessageType::ToggleNext => { + server.state.playlist_params.next = !server.state.playlist_params.next; + return Message::state_response(&server.state); + } + MessageType::ToggleSuffle => { + server.state.playlist_params.shuffle = !server.state.playlist_params.shuffle; + return Message::state_response(&server.state); + } + MessageType::ToggleRepeat => { + server.state.playlist_params.repeat = !server.state.playlist_params.repeat; + return Message::state_response(&server.state); + } + _ => {} + } + Ok(Message::new(MessageType::NotImplementedAck, None)) +}