rmp/src/protocol.rs

133 lines
3.8 KiB
Rust

use std::{
fmt::Display,
io::{Read, Write},
};
use interprocess::local_socket::LocalSocketStream;
use serde::{Deserialize, Serialize};
use crate::ServerState;
/// Magic bytes that open every frame on the wire (`CA FE BA BE`).
pub const HEADER_MAGIC: [u8; 4] = [0xCA, 0xFE, 0xBA, 0xBE];
/// Largest frame body, in bytes, that the receiver will accept (10 MiB).
pub const MAX_BODY_LENGTH: usize = 10 << 20;
#[derive(Debug, Serialize, Deserialize)]
pub enum MessageError {
HeaderMismatch,
BodySizeLimit,
ChecksumMismatch,
ReadError,
DeserializationError,
}
/// Tag identifying what a [`Message`] means.
///
/// NOTE(review): values are serialized through serde/bincode as part of `Message`
/// (see `Message::as_bytes`). Reordering or inserting variants may change the wire
/// encoding, so confirm bincode's enum representation before doing so. The explicit
/// `= 0` on `Ack` is a Rust discriminant and presumably does not control serde's encoding.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Deserialize, Serialize)]
pub enum MessageType {
    /// Generic acknowledge
    Ack = 0,
    /// client/server did not know how to handle the request
    NotImplementedAck,
    /// Request was invalid
    ProtocolError,
    /// Presumably a request for the server's state; the reply is likely `StateResponse` — confirm against the server loop.
    StateFetch,
    /// Carries a bincode-serialized `ServerState` as its body (built by `Message::state_response`).
    StateResponse,
    /// Presumably toggles shuffle mode.
    /// NOTE(review): the name misspells "Shuffle". Renaming it would break every
    /// caller that names this variant, so it is kept as-is.
    ToggleSuffle,
    /// Presumably a "next track" command — confirm intended semantics.
    ToggleNext,
    /// Presumably toggles repeat mode.
    ToggleRepeat,
    /// Presumably asks to play a track; the body likely identifies it — TODO confirm body format.
    PlayTrack,
    /// Presumably toggles pause/resume.
    TogglePause,
}
/// A protocol message: a type tag plus an optional opaque payload.
///
/// The whole struct is serialized with bincode to form the body of a wire frame
/// (see `Message::as_bytes` and `receive`). Field order is part of that serialized
/// form, so do not reorder fields without updating every peer.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// What kind of message this is.
    pub message_type: MessageType,
    /// Optional payload bytes. For `MessageType::StateResponse` this holds a
    /// bincode-serialized `ServerState` (see `Message::state_response`).
    pub body: Option<Vec<u8>>,
}
impl Message {
    /// Creates a message of `message_type`, copying `body` (if any) into an owned buffer.
    pub fn new(message_type: MessageType, body: Option<&[u8]>) -> Self {
        Self {
            message_type,
            body: body.map(<[u8]>::to_vec),
        }
    }

    /// Encodes this message as one complete wire frame.
    ///
    /// Frame layout (integers are little-endian):
    ///
    /// offset | size | explanation
    /// -------+------+-----------
    /// 0x00   | 4    | HEADER_MAGIC
    /// 0x04   | u32  | CRC-32 checksum of the body
    /// 0x08   | u32  | Body length in bytes
    /// 0x0C   | ?    | Body: this whole `Message`, serialized with bincode
    ///
    /// # Panics
    ///
    /// Panics if bincode fails to serialize the message, or if the serialized
    /// message is longer than `u32::MAX` bytes and cannot be described by the
    /// length field. Panicking is preferred over silently truncating the length,
    /// which would produce a frame the peer cannot parse.
    fn as_bytes(&self) -> Vec<u8> {
        let body = bincode::serialize(self)
            .expect("serializing a Message (enum tag + optional byte vector) should not fail");
        let body_length: u32 = std::convert::TryFrom::try_from(body.len())
            .expect("serialized message length must fit in the u32 length field");
        let checksum = crc32fast::hash(&body);
        [
            &HEADER_MAGIC[..],
            &checksum.to_le_bytes()[..],
            &body_length.to_le_bytes()[..],
            &body[..],
        ]
        .concat()
    }

    /// Builds a `StateResponse` message whose body is `server_state` serialized with bincode.
    ///
    /// # Errors
    ///
    /// Returns the bincode serialization error rendered as a `String`.
    pub fn state_response(server_state: &ServerState) -> Result<Self, String> {
        let body = bincode::serialize(server_state).map_err(|err| err.to_string())?;
        Ok(Self {
            message_type: MessageType::StateResponse,
            body: Some(body),
        })
    }
}
impl Display for Message {
    /// Renders the message type (in its `Debug` form), a tab, and then either the
    /// body size as `<len> B` or the literal `(no body)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}\t", self.message_type)?;
        match &self.body {
            Some(payload) => write!(f, "{} B", payload.len()),
            None => f.write_str("(no body)"),
        }
    }
}
/// Encodes `message` as a single frame and writes all of it to `stream`.
///
/// # Errors
///
/// Propagates any I/O error raised while writing to `stream`.
pub fn send(stream: &mut LocalSocketStream, message: &Message) -> Result<(), std::io::Error> {
    let frame = message.as_bytes();
    stream.write_all(&frame)
}
/// Reads one frame from `stream` and decodes it into a [`Message`].
///
/// The frame must have the layout written by `send`:
/// 1. 4 magic bytes.
/// 2. A little-endian `u32` CRC-32 of the body.
/// 3. A little-endian `u32` body length.
/// 4. The body itself: a bincode-serialized `Message`.
///
/// # Errors
///
/// - `MessageError::ReadError`: reading fails or the stream ends before the frame is complete.
/// - `MessageError::HeaderMismatch`: the first four bytes are not [`HEADER_MAGIC`].
/// - `MessageError::BodySizeLimit`: the declared length exceeds [`MAX_BODY_LENGTH`].
///   This is checked before the body buffer is allocated.
/// - `MessageError::ChecksumMismatch`: the body's CRC-32 differs from the header value.
/// - `MessageError::DeserializationError`: the body is not a valid bincode `Message`.
///
/// On error, the stream may be left part-way through a frame.
pub fn receive(stream: &mut LocalSocketStream) -> Result<Message, MessageError> {
    // Fills `buf` completely from `stream`; any I/O failure (including EOF) becomes `ReadError`.
    fn fill(stream: &mut LocalSocketStream, buf: &mut [u8]) -> Result<(), MessageError> {
        stream.read_exact(buf).map_err(|_| MessageError::ReadError)
    }

    // Fixed-size stack buffer. The comparison with `HEADER_MAGIC` only compiles
    // while both are `[u8; 4]`, so a size change cannot go unnoticed.
    let mut magic = [0u8; 4];
    fill(stream, &mut magic)?;
    if magic != HEADER_MAGIC {
        return Err(MessageError::HeaderMismatch);
    }

    let mut checksum_bytes = [0u8; 4];
    fill(stream, &mut checksum_bytes)?;
    let expected_checksum = u32::from_le_bytes(checksum_bytes);

    let mut length_bytes = [0u8; 4];
    fill(stream, &mut length_bytes)?;
    let body_length = u32::from_le_bytes(length_bytes) as usize;
    // Reject oversized frames before allocating, so a peer cannot force a huge allocation.
    if body_length > MAX_BODY_LENGTH {
        return Err(MessageError::BodySizeLimit);
    }

    let mut body = vec![0u8; body_length];
    fill(stream, &mut body)?;
    if crc32fast::hash(&body) != expected_checksum {
        return Err(MessageError::ChecksumMismatch);
    }

    bincode::deserialize(&body).map_err(|_| MessageError::DeserializationError)
}