use std::{ fmt::Display, io::{Read, Write}, }; use interprocess::local_socket::LocalSocketStream; use serde::{Deserialize, Serialize}; use crate::ServerState; /// Prefix messages with this header pub const HEADER_MAGIC: [u8; 4] = [0xCA, 0xFE, 0xBA, 0xBE]; /// Maximum allowed body size pub const MAX_BODY_LENGTH: usize = 10 * 1024 * 1024; #[derive(Debug, Serialize, Deserialize)] pub enum MessageError { HeaderMismatch, BodySizeLimit, ChecksumMismatch, ReadError, DeserializationError, } #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Deserialize, Serialize)] pub enum MessageType { /// Generic acknowledge Ack = 0, /// client/server did not know how to handle the request NotImplementedAck, /// Request was invalid ProtocolError, StateFetch, StateResponse, ToggleSuffle, ToggleNext, ToggleRepeat, PlayTrack, TogglePause, } #[derive(Debug, Serialize, Deserialize)] pub struct Message { pub message_type: MessageType, pub body: Option>, } impl Message { pub fn new(message_type: MessageType, body: Option<&[u8]>) -> Self { Self { message_type, body: body.map(|b| Vec::from(b)), } } /// Message format (values are in little-endian): /// offset | size | explanation /// -------+------+----------- /// 0x00 | u32 | HEADER_MAGIC /// 0x04 | u32 | Body checksum /// 0x08 | u32 | Body length /// 0x12 | ? | Body fn as_bytes(&self) -> Vec { let magic = &HEADER_MAGIC[..]; let body = &bincode::serialize(self).unwrap(); let checksum = &crc32fast::hash(&body).to_le_bytes(); let body_length = &(body.len() as u32).to_le_bytes(); [magic, checksum, body_length, body].concat() } pub fn state_response(server_state: &ServerState) -> Result { Ok(Self { message_type: MessageType::StateResponse, body: Some(bincode::serialize(server_state).map_err(|err| err.to_string())?), }) } } impl Display for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{:?}\t{}", self.message_type, self.body .as_ref() .map_or("(no body)".into(), |body| format!("{} B", body.len())) ) } } pub fn send(stream: &mut LocalSocketStream, message: &Message) -> Result<(), std::io::Error> { stream.write_all(&message.as_bytes())?; Ok(()) } pub fn receive(stream: &mut LocalSocketStream) -> Result { let mut magic_buffer = vec![0; HEADER_MAGIC.len()]; if let Err(_) = stream.read_exact(&mut magic_buffer) { return Err(MessageError::ReadError); } if magic_buffer != HEADER_MAGIC { return Err(MessageError::HeaderMismatch); } let mut checksum_buffer = [0; 4]; if let Err(_) = stream.read_exact(&mut checksum_buffer) { return Err(MessageError::ReadError); } let expected_checksum = u32::from_le_bytes(checksum_buffer); let mut body_length_buffer = [0; 4]; if let Err(_) = stream.read_exact(&mut body_length_buffer) { return Err(MessageError::ReadError); } let expected_body_length = u32::from_le_bytes(body_length_buffer) as usize; if expected_body_length > MAX_BODY_LENGTH { return Err(MessageError::BodySizeLimit); } let mut body_buffer = vec![0; expected_body_length]; if let Err(_) = stream.read_exact(&mut body_buffer) { return Err(MessageError::ReadError); } if crc32fast::hash(&body_buffer) != expected_checksum { return Err(MessageError::ChecksumMismatch); } bincode::deserialize(&body_buffer).map_err(|_| MessageError::DeserializationError) }