1use std::collections::VecDeque;
3use std::fmt::Debug;
4use std::io::{Cursor, ErrorKind};
5#[cfg(test)]
6use std::net::{SocketAddr, TcpListener, TcpStream};
7use std::option::Option;
8
9use bson::doc;
11use bson::document::ValueAccessError;
12
13use crate::byte_counter::ByteCounter;
14
/// Error codes carried by Honk-RPC error sections.
///
/// The wire value of each protocol error is listed below (see the
/// `From<i32>` / `From<ErrorCode> for i32` conversions). Positive wire values
/// decode to [`ErrorCode::Runtime`]; zero and unrecognised negative values
/// decode to [`ErrorCode::Unknown`].
#[derive(PartialEq, Eq, Debug)]
pub enum ErrorCode {
    /// -1: a received message could not be parsed as BSON.
    BsonParseFailed,
    /// -2: a received document was too large.
    MessageTooBig,
    /// -3: a received message does not match the message schema.
    MessageParseFailed,
    /// -4: a received message declares an incompatible protocol version.
    MessageVersionIncompatible,
    /// -5: a received message contains a section of unknown type.
    SectionIdUnknown,
    /// -6: a received message contains a section with an invalid schema.
    SectionParseFailed,
    /// -7: the request cookie is already in use.
    RequestCookieInvalid,
    /// -8: the requested function does not exist in the requested namespace.
    RequestNamespaceInvalid,
    /// -9: the requested function does not exist.
    RequestFunctionInvalid,
    /// -10: the requested function version does not exist.
    RequestVersionInvalid,
    /// -11: the response cookie is not recognised.
    ResponseCookieInvalid,
    /// -12: the response state is not valid.
    ResponseStateInvalid,
    /// Application-defined error; `From<i32>` produces this for positive codes.
    Runtime(i32),
    /// Any other code, preserved verbatim.
    Unknown(i32),
}
49
50#[derive(thiserror::Error, Debug)]
52pub enum Error {
53 #[error("failed to read data from read stream")]
55 ReaderReadFailed(#[source] std::io::Error),
56
57 #[error("received invalid bson document size header value of {0}, must be at least 4")]
59 BsonDocumentSizeTooSmall(i32),
60
61 #[error("received invalid bson document size header value of {0}, must be less than {1}")]
63 BsonDocumentSizeTooLarge(i32, i32),
64
65 #[error("waited longer than {} seconds for read", .0.as_secs_f32())]
67 MessageReadTimedOut(std::time::Duration),
68
69 #[error("failed to parse bson Message document")]
71 BsonDocumentParseFailed(#[source] bson::de::Error),
72
73 #[error("failed to convert bson document to Message")]
75 MessageConversionFailed(#[source] crate::honk_rpc::ErrorCode),
76
77 #[error("failed to serialize bson document")]
79 BsonWriteFailed(#[source] bson::ser::Error),
80
81 #[error("failed to write data to write stream")]
83 WriterWriteFailed(#[source] std::io::Error),
84
85 #[error("failed to flush message to write stream")]
87 WriterFlushFailed(#[source] std::io::Error),
88
89 #[error("recieved error section without cookie")]
91 UnknownErrorSectionReceived(#[source] crate::honk_rpc::ErrorCode),
92
93 #[error(
95 "tried to set invalid max message size; must be >=5 bytes and <= i32::MAX (2147483647)"
96 )]
97 InvalidMaxMesageSize(),
98
99 #[error("queued message section is too large to write; calculated size is {0} but must be less than {1}")]
101 SectionTooLarge(usize, usize),
102}
103
104impl From<i32> for ErrorCode {
105 fn from(value: i32) -> ErrorCode {
106 match value {
107 -1i32 => ErrorCode::BsonParseFailed,
108 -2i32 => ErrorCode::MessageTooBig,
109 -3i32 => ErrorCode::MessageParseFailed,
110 -4i32 => ErrorCode::MessageVersionIncompatible,
111 -5i32 => ErrorCode::SectionIdUnknown,
112 -6i32 => ErrorCode::SectionParseFailed,
113 -7i32 => ErrorCode::RequestCookieInvalid,
114 -8i32 => ErrorCode::RequestNamespaceInvalid,
115 -9i32 => ErrorCode::RequestFunctionInvalid,
116 -10i32 => ErrorCode::RequestVersionInvalid,
117 -11i32 => ErrorCode::ResponseCookieInvalid,
118 -12i32 => ErrorCode::ResponseStateInvalid,
119 value => {
120 if value > 0 {
121 ErrorCode::Runtime(value)
122 } else {
123 ErrorCode::Unknown(value)
124 }
125 }
126 }
127 }
128}
129
130impl From<ErrorCode> for i32 {
131 fn from(err: ErrorCode) -> Self {
132 match err {
133 ErrorCode::BsonParseFailed => -1i32,
134 ErrorCode::MessageTooBig => -2i32,
135 ErrorCode::MessageParseFailed => -3i32,
136 ErrorCode::MessageVersionIncompatible => -4i32,
137 ErrorCode::SectionIdUnknown => -5i32,
138 ErrorCode::SectionParseFailed => -6i32,
139 ErrorCode::RequestCookieInvalid => -7i32,
140 ErrorCode::RequestNamespaceInvalid => -8i32,
141 ErrorCode::RequestFunctionInvalid => -9i32,
142 ErrorCode::RequestVersionInvalid => -10i32,
143 ErrorCode::ResponseCookieInvalid => -11i32,
144 ErrorCode::ResponseStateInvalid => -12i32,
145 ErrorCode::Runtime(val) => val,
146 ErrorCode::Unknown(val) => val,
147 }
148 }
149}
150
151impl std::fmt::Display for ErrorCode {
152 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
153 match self {
154 ErrorCode::BsonParseFailed => write!(f, "ProtocolError: failed to parse BSON object"),
155 ErrorCode::MessageTooBig => write!(f, "ProtocolError: received document too large"),
156 ErrorCode::MessageParseFailed => {
157 write!(f, "ProtocolError: received message has invalid schema")
158 }
159 ErrorCode::MessageVersionIncompatible => write!(
160 f,
161 "ProtocolError: received message has incompatible version"
162 ),
163 ErrorCode::SectionIdUnknown => write!(
164 f,
165 "ProtocolError: received message contains section of unknown type"
166 ),
167 ErrorCode::SectionParseFailed => write!(
168 f,
169 "ProtocolError: received message contains section with invalid schema"
170 ),
171 ErrorCode::RequestCookieInvalid => {
172 write!(f, "ProtocolError: request cookie already in use")
173 }
174 ErrorCode::RequestNamespaceInvalid => write!(
175 f,
176 "ProtocolError: request function does not exist in requested namespace"
177 ),
178 ErrorCode::RequestFunctionInvalid => {
179 write!(f, "ProtocolError: request function does not exist")
180 }
181 ErrorCode::RequestVersionInvalid => {
182 write!(f, "ProtocolError: request function version does not exist")
183 }
184 ErrorCode::ResponseCookieInvalid => {
185 write!(f, "ProtocolError: response cookie is not recognized")
186 }
187 ErrorCode::ResponseStateInvalid => write!(f, "ProtocolError: response state not valid"),
188 ErrorCode::Runtime(code) => write!(f, "RuntimeError: runtime error {}", code),
189 ErrorCode::Unknown(code) => write!(f, "UnknownError: unknown error code {}", code),
190 }
191 }
192}
193
194impl std::error::Error for ErrorCode {}
195
/// Packs `major.minor.patch` into a single `i32` laid out as `0x00MMmmpp`.
///
/// This is the representation used by the `honk_rpc` message field.
const fn semver_to_i32(major: u8, minor: u8, patch: u8) -> i32 {
    ((major as i32) << 16) | ((minor as i32) << 8) | (patch as i32)
}

/// Unpacks a value produced by [`semver_to_i32`].
///
/// Returns `None` for anything outside `0..=0xFFFFFF`, which cannot be a
/// packed version.
const fn i32_to_semver(ver: i32) -> Option<(u8, u8, u8)> {
    match ver {
        // `as u8` keeps only the low byte, isolating each component after the shift.
        0..=0x00ff_ffff => Some(((ver >> 16) as u8, (ver >> 8) as u8, ver as u8)),
        _ => None,
    }
}

/// Protocol version 0.1.0.
///
/// Stamped on every outgoing message; incoming messages must carry exactly
/// this value.
const HONK_RPC_VERSION: i32 = semver_to_i32(0, 1, 0);
217
218struct Message {
219 honk_rpc: i32,
220 sections: Vec<Section>,
221}
222
223impl TryFrom<bson::document::Document> for Message {
224 type Error = ErrorCode;
225
226 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
227 let mut value = value;
228
229 let honk_rpc = match value.get_i32("honk_rpc") {
231 Ok(HONK_RPC_VERSION) => HONK_RPC_VERSION,
232 Ok(honk_rpc) => {
233 return if let Some(_version) = i32_to_semver(honk_rpc) {
234 Err(ErrorCode::MessageVersionIncompatible)
236 } else {
237 Err(ErrorCode::MessageParseFailed)
239 };
240 }
241 Err(_err) => return Err(ErrorCode::MessageParseFailed),
242 };
243
244 if let Ok(sections) = value.get_array_mut("sections") {
245 if sections.is_empty() {
247 return Err(ErrorCode::MessageParseFailed);
248 }
249
250 let mut message = Message {
251 honk_rpc,
252 sections: Default::default(),
253 };
254
255 for section in sections.iter_mut() {
256 if let bson::Bson::Document(section) = std::mem::take(section) {
257 message.sections.push(Section::try_from(section)?);
258 } else {
259 return Err(ErrorCode::SectionParseFailed);
260 }
261 }
262 Ok(message)
263 } else {
264 Err(ErrorCode::MessageParseFailed)
265 }
266 }
267}
268
269impl From<Message> for bson::document::Document {
270 fn from(value: Message) -> bson::document::Document {
271 let mut value = value;
272 let mut message = bson::document::Document::new();
273 message.insert("honk_rpc", value.honk_rpc);
274
275 let mut sections = bson::Array::new();
276 for section in value.sections.drain(0..) {
277 sections.push(bson::Bson::Document(bson::document::Document::from(
278 section,
279 )));
280 }
281 message.insert("sections", sections);
282
283 message
284 }
285}
286
/// Identifier that pairs a request with the response or error section answering it.
pub type RequestCookie = i64;

/// `id` value marking an error section.
const ERROR_SECTION_ID: i32 = 0;
/// `id` value marking a request section.
const REQUEST_SECTION_ID: i32 = 1;
/// `id` value marking a response section.
const RESPONSE_SECTION_ID: i32 = 2;
293
294enum Section {
295 Error(ErrorSection),
296 Request(RequestSection),
297 Response(ResponseSection),
298}
299
300struct ErrorSection {
301 cookie: Option<RequestCookie>,
302 code: ErrorCode,
303 message: Option<String>,
304 data: Option<bson::Bson>,
305}
306
307struct RequestSection {
308 cookie: Option<RequestCookie>,
309 namespace: String,
310 function: String,
311 version: i32,
312 arguments: bson::document::Document,
313}
314
/// Value of a response section's `state` field.
#[derive(PartialEq, Debug)]
#[repr(i32)]
enum RequestState {
    /// The request was accepted; its result will be sent in a later response.
    Pending = 0,
    /// The request finished; the response carries its (optional) result.
    Complete = 1,
}
321
322struct ResponseSection {
323 cookie: RequestCookie,
324 state: RequestState,
325 result: Option<bson::Bson>,
326}
327
328impl TryFrom<bson::document::Document> for Section {
329 type Error = ErrorCode;
330
331 fn try_from(
332 value: bson::document::Document,
333 ) -> Result<Self, <Self as TryFrom<bson::document::Document>>::Error> {
334 match value.get_i32("id") {
335 Ok(ERROR_SECTION_ID) => Ok(Section::Error(ErrorSection::try_from(value)?)),
336 Ok(REQUEST_SECTION_ID) => Ok(Section::Request(RequestSection::try_from(value)?)),
337 Ok(RESPONSE_SECTION_ID) => Ok(Section::Response(ResponseSection::try_from(value)?)),
338 Ok(_) => Err(ErrorCode::SectionIdUnknown),
339 Err(_) => Err(ErrorCode::SectionParseFailed),
340 }
341 }
342}
343
344impl From<Section> for bson::document::Document {
345 fn from(value: Section) -> bson::document::Document {
346 match value {
347 Section::Error(section) => bson::document::Document::from(section),
348 Section::Request(section) => bson::document::Document::from(section),
349 Section::Response(section) => bson::document::Document::from(section),
350 }
351 }
352}
353
354impl TryFrom<bson::document::Document> for ErrorSection {
355 type Error = ErrorCode;
356
357 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
358 let mut value = value;
359
360 let cookie = match value.get_i64("cookie") {
361 Ok(cookie) => Some(cookie),
362 Err(ValueAccessError::NotPresent) => None,
363 Err(_) => return Err(ErrorCode::SectionParseFailed),
364 };
365
366 let code = match value.get_i32("code") {
367 Ok(code) => ErrorCode::from(code),
368 Err(_) => return Err(ErrorCode::SectionParseFailed),
369 };
370
371 let message = match value.get_str("message") {
372 Ok(message) => Some(message.to_string()),
373 Err(ValueAccessError::NotPresent) => None,
374 Err(_) => return Err(ErrorCode::SectionParseFailed),
375 };
376
377 let data = value.get_mut("data").map(std::mem::take);
378
379 Ok(ErrorSection {
380 cookie,
381 code,
382 message,
383 data,
384 })
385 }
386}
387
388impl From<ErrorSection> for bson::document::Document {
389 fn from(value: ErrorSection) -> bson::document::Document {
390 let mut error_section = bson::document::Document::new();
391 error_section.insert("id", ERROR_SECTION_ID);
392
393 if let Some(cookie) = value.cookie {
394 error_section.insert("cookie", cookie);
395 }
396
397 error_section.insert("code", Into::<i32>::into(value.code));
398
399 if let Some(message) = value.message {
400 error_section.insert("message", message);
401 }
402
403 if let Some(data) = value.data {
404 error_section.insert("data", data);
405 }
406
407 error_section
408 }
409}
410
411impl TryFrom<bson::document::Document> for RequestSection {
412 type Error = ErrorCode;
413
414 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
415 let mut value = value;
416
417 let cookie = match value.get_i64("cookie") {
418 Ok(cookie) => Some(cookie),
419 Err(ValueAccessError::NotPresent) => None,
420 Err(_) => return Err(ErrorCode::SectionParseFailed),
421 };
422
423 let namespace = match value.get_str("namespace") {
424 Ok(namespace) => namespace.to_string(),
425 Err(ValueAccessError::NotPresent) => String::default(),
426 Err(_) => return Err(ErrorCode::SectionParseFailed),
427 };
428
429 let function = match value.get_str("function") {
430 Ok(function) => {
431 if function.is_empty() {
432 return Err(ErrorCode::RequestFunctionInvalid);
433 } else {
434 function.to_string()
435 }
436 }
437 Err(_) => return Err(ErrorCode::SectionParseFailed),
438 };
439
440 let version = match value.get_i32("version") {
441 Ok(version) => version,
442 Err(ValueAccessError::NotPresent) => 0i32,
443 Err(_) => return Err(ErrorCode::SectionParseFailed),
444 };
445
446 let arguments = match value.get_document_mut("arguments") {
447 Ok(arguments) => std::mem::take(arguments),
448 Err(ValueAccessError::NotPresent) => bson::document::Document::new(),
449 Err(_) => return Err(ErrorCode::SectionParseFailed),
450 };
451
452 Ok(RequestSection {
453 cookie,
454 namespace,
455 function,
456 version,
457 arguments,
458 })
459 }
460}
461
462impl From<RequestSection> for bson::document::Document {
463 fn from(value: RequestSection) -> bson::document::Document {
464 let mut request_section = bson::document::Document::new();
465 request_section.insert("id", REQUEST_SECTION_ID);
466
467 if let Some(cookie) = value.cookie {
468 request_section.insert("cookie", cookie);
469 }
470
471 if !value.namespace.is_empty() {
472 request_section.insert("namespace", value.namespace);
473 }
474
475 request_section.insert("function", value.function);
476
477 if value.version != 0i32 {
478 request_section.insert("version", value.version);
479 }
480
481 request_section.insert("arguments", value.arguments);
482
483 request_section
484 }
485}
486
487impl TryFrom<bson::document::Document> for ResponseSection {
488 type Error = ErrorCode;
489
490 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
491 let mut value = value;
492 let cookie = match value.get_i64("cookie") {
493 Ok(cookie) => cookie,
494 Err(_) => return Err(ErrorCode::SectionParseFailed),
495 };
496
497 let state = match value.get_i32("state") {
498 Ok(0i32) => RequestState::Pending,
499 Ok(1i32) => RequestState::Complete,
500 Ok(_) => return Err(ErrorCode::ResponseStateInvalid),
501 Err(_) => return Err(ErrorCode::SectionParseFailed),
502 };
503
504 let result = value.get_mut("result").map(std::mem::take);
505
506 if state == RequestState::Pending && result.is_some() {
508 return Err(ErrorCode::SectionParseFailed);
509 }
510
511 Ok(ResponseSection {
512 cookie,
513 state,
514 result,
515 })
516 }
517}
518
519impl From<ResponseSection> for bson::document::Document {
520 fn from(value: ResponseSection) -> bson::document::Document {
521 let mut response_section = bson::document::Document::new();
522 response_section.insert("id", RESPONSE_SECTION_ID);
523
524 response_section.insert("cookie", value.cookie);
525 response_section.insert("state", value.state as i32);
526
527 if let Some(result) = value.result {
528 response_section.insert("result", result);
529 }
530
531 response_section
532 }
533}
534
535pub trait ApiSet {
628 fn namespace(&self) -> &str;
630
631 fn exec_function(
645 &mut self,
646 name: &str,
647 version: i32,
648 args: bson::document::Document,
649 request_cookie: Option<RequestCookie>,
650 ) -> Option<Result<Option<bson::Bson>, ErrorCode>>;
651
652 fn update(&mut self) {}
657
658 fn next_result(&mut self) -> Option<(RequestCookie, Result<Option<bson::Bson>, ErrorCode>)> {
666 None
667 }
668}
669
670pub enum Response {
672 Pending {
674 cookie: RequestCookie,
676 },
677 Success {
679 cookie: RequestCookie,
681 result: Option<bson::Bson>,
683 },
684 Error {
686 cookie: RequestCookie,
688 error_code: ErrorCode,
690 },
691}
692
/// Default cap, in bytes, on a single message sent or received (4 KiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 * 1024;
/// Default time a session may go without receiving any bytes before a read
/// fails with `Error::MessageReadTimedOut`.
pub const DEFAULT_MAX_WAIT_TIME: std::time::Duration = std::time::Duration::from_secs(60);

// Byte budget of `{ honk_rpc: <i32>, sections: [<one section>] }`, excluding
// the bytes of the section document itself.

/// BSON document length prefix.
const HEADER_SIZE: usize = 4;
/// `honk_rpc` element: type tag (1) + "honk_rpc\0" (9) + int32 (4).
const HONK_RPC_SIZE: usize = 14;
/// `sections` element with a single entry, excluding that entry's document bytes:
/// type tag (1) + "sections\0" (9) + array length (4) + entry tag (1) + "0\0" (2)
/// + array terminator (1).
const SECTIONS_SIZE: usize = 18;
/// BSON document terminator.
const FOOTER_SIZE: usize = 1;

/// Size of the smallest possible message, excluding its single section's bytes.
const MIN_MESSAGE_SIZE: usize = HEADER_SIZE + HONK_RPC_SIZE + SECTIONS_SIZE + FOOTER_SIZE;
711
712pub fn get_message_overhead() -> Result<usize, Error> {
716 let message = doc! {
719 "honk_rpc" : HONK_RPC_VERSION,
720 "sections" : [
721 bson::Bson::Null
722 ]
723 };
724
725 let mut counter: ByteCounter = Default::default();
726 message
727 .to_writer(&mut counter)
728 .map_err(Error::BsonWriteFailed)?;
729
730 Ok(counter.bytes())
731}
732
733pub fn get_error_section_size(
738 cookie: Option<RequestCookie>,
739 message: Option<String>,
740 data: Option<bson::Bson>,
741) -> Result<usize, Error> {
742 let mut error_section = doc! {
743 "id": ERROR_SECTION_ID,
744 "code": Into::<i32>::into(ErrorCode::Unknown(0)),
745 };
746
747 if let Some(cookie) = cookie {
748 error_section.insert("cookie", bson::Bson::Int64(cookie));
749 }
750
751 if let Some(message) = message {
752 error_section.insert("message", bson::Bson::String(message));
753 }
754
755 if let Some(data) = data {
756 error_section.insert("data", data);
757 }
758
759 let mut counter: ByteCounter = Default::default();
760 error_section
761 .to_writer(&mut counter)
762 .map_err(Error::BsonWriteFailed)?;
763
764 Ok(counter.bytes())
765}
766
767pub fn get_request_section_size(
772 cookie: Option<RequestCookie>,
773 namespace: Option<String>,
774 function: String,
775 version: Option<i32>,
776 arguments: Option<bson::Document>,
777) -> Result<usize, Error> {
778 let mut request_section = doc! {
779 "id": REQUEST_SECTION_ID,
780 "function": bson::Bson::String(function),
781 };
782
783 if let Some(cookie) = cookie {
784 request_section.insert("cookie", bson::Bson::Int64(cookie));
785 }
786
787 if let Some(namespace) = namespace {
788 request_section.insert("namespace", bson::Bson::String(namespace));
789 }
790
791 if let Some(version) = version {
792 request_section.insert("version", bson::Bson::Int32(version));
793 }
794
795 if let Some(arguments) = arguments {
796 request_section.insert("arguments", arguments);
797 }
798
799 let mut counter: ByteCounter = Default::default();
800 request_section
801 .to_writer(&mut counter)
802 .map_err(Error::BsonWriteFailed)?;
803
804 Ok(counter.bytes())
805}
806
807pub fn get_response_section_size(result: Option<bson::Bson>) -> Result<usize, Error> {
812 let mut response_section = doc! {
813 "id": RESPONSE_SECTION_ID,
814 "cookie": bson::Bson::Int64(0),
815 "state": bson::Bson::Int32(0),
816 };
817
818 if let Some(result) = result {
819 response_section.insert("result", result);
820 }
821
822 let mut counter: ByteCounter = Default::default();
823 response_section
824 .to_writer(&mut counter)
825 .map_err(Error::BsonWriteFailed)?;
826
827 Ok(counter.bytes())
828}
829
830pub struct Session<RW> {
835 stream: RW,
837 message_write_buffer: VecDeque<u8>,
839
840 remaining_byte_count: Option<usize>,
845 message_read_buffer: Vec<u8>,
847 pending_sections: VecDeque<Section>,
849 inbound_requests: Vec<RequestSection>,
851 inbound_responses: VecDeque<Response>,
853
854 message_serialization_buffer: VecDeque<u8>,
858 next_cookie: RequestCookie,
860 outbound_sections: Vec<bson::Document>,
862
863 max_message_size: usize,
865 max_wait_time: std::time::Duration,
868 read_timestamp: std::time::Instant,
870}
871
872#[allow(dead_code)]
873impl<RW> Session<RW>
874where
875 RW: std::io::Read + std::io::Write + Send,
876{
877 pub fn set_max_message_size(&mut self, max_message_size: i32) -> Result<(), Error> {
879 if max_message_size < MIN_MESSAGE_SIZE as i32 {
880 Err(Error::InvalidMaxMesageSize())
882 } else {
883 self.max_message_size = max_message_size as usize;
884 Ok(())
885 }
886 }
887
888 pub fn get_max_message_size(&self) -> usize {
890 self.max_message_size
891 }
892
893 pub fn set_max_wait_time(&mut self, max_wait_time: std::time::Duration) {
895 self.max_wait_time = max_wait_time;
896 }
897
898 pub fn get_max_wait_time(&self) -> std::time::Duration {
900 self.max_wait_time
901 }
902
903 pub fn new(stream: RW) -> Self {
905 let mut message_write_buffer: VecDeque<u8> = Default::default();
906 message_write_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
907
908 let mut message_serialization_buffer: VecDeque<u8> = Default::default();
909 message_serialization_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
910
911 Session {
912 stream,
913 message_write_buffer,
914 remaining_byte_count: None,
915 message_read_buffer: Default::default(),
916 pending_sections: Default::default(),
917 inbound_requests: Default::default(),
918 inbound_responses: Default::default(),
919 message_serialization_buffer,
920 next_cookie: Default::default(),
921 outbound_sections: Default::default(),
922 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
923 max_wait_time: DEFAULT_MAX_WAIT_TIME,
924 read_timestamp: std::time::Instant::now(),
925 }
926 }
927
928 pub fn into_stream(self) -> RW {
930 self.stream
931 }
932
933 fn stream_read(&mut self, buffer: &mut [u8]) -> Result<usize, Error> {
935 match self.stream.read(buffer) {
936 Err(err) => {
937 if err.kind() == ErrorKind::WouldBlock || err.kind() == ErrorKind::TimedOut {
938 if std::time::Instant::now().duration_since(self.read_timestamp)
940 > self.max_wait_time
941 {
942 Err(Error::MessageReadTimedOut(self.max_wait_time))
943 } else {
944 Ok(0)
945 }
946 } else {
947 Err(Error::ReaderReadFailed(err))
948 }
949 }
950 Ok(0) => Err(Error::ReaderReadFailed(std::io::Error::from(
951 ErrorKind::UnexpectedEof,
952 ))),
953 Ok(count) => {
954 self.read_timestamp = std::time::Instant::now();
956 Ok(count)
957 }
958 }
959 }
960
961 fn read_message_size(&mut self) -> Result<(), Error> {
963 match self.remaining_byte_count {
964 Some(_remaining) => Ok(()),
966 None => {
968 assert!(self.message_read_buffer.len() < std::mem::size_of::<i32>());
970 let bytes_needed = std::mem::size_of::<i32>() - self.message_read_buffer.len();
971 let mut buffer = [0u8; std::mem::size_of::<i32>()];
973 let buffer = &mut buffer[0..bytes_needed];
975 match self.stream_read(buffer) {
976 Err(err) => Err(err),
977 Ok(0) => Ok(()),
978 Ok(count) => {
979 #[cfg(test)]
980 println!("<<< read {} bytes for message header", count);
981 self.message_read_buffer
982 .extend_from_slice(&buffer[0..count]);
983
984 if self.message_read_buffer.len() == std::mem::size_of::<i32>() {
986 let size = &self.message_read_buffer.as_slice();
987 let size: i32 = (size[0] as i32)
988 | (size[1] as i32) << 8
989 | (size[2] as i32) << 16
990 | (size[3] as i32) << 24;
991 if size <= std::mem::size_of::<i32>() as i32 {
993 return Err(Error::BsonDocumentSizeTooSmall(size));
994 }
995 if size as usize > self.max_message_size {
997 return Err(Error::BsonDocumentSizeTooLarge(
998 size,
999 self.max_message_size as i32,
1000 ));
1001 }
1002
1003 let size = size as usize - std::mem::size_of::<i32>();
1005
1006 self.remaining_byte_count = Some(size);
1007 }
1008 Ok(())
1009 }
1010 }
1011 }
1012 }
1013 }
1014
1015 fn read_message(&mut self) -> Result<Option<Message>, Error> {
1017 self.read_message_size()?;
1019 if let Some(remaining) = self.remaining_byte_count {
1021 #[cfg(test)]
1022 println!("--- message requires {} more bytes", remaining);
1023
1024 let mut buffer = vec![0u8; remaining];
1025 match self.stream_read(&mut buffer) {
1026 Err(err) => Err(err),
1027 Ok(0) => Ok(None),
1028 Ok(count) => {
1029 #[cfg(test)]
1030 println!("<<< read {} bytes", count);
1031 self.message_read_buffer
1033 .extend_from_slice(&buffer[0..count]);
1034 if remaining == count {
1035 self.remaining_byte_count = None;
1036
1037 let mut cursor = Cursor::new(std::mem::take(&mut self.message_read_buffer));
1038 let bson = bson::document::Document::from_reader(&mut cursor)
1039 .map_err(Error::BsonDocumentParseFailed)?;
1040
1041 self.message_read_buffer = cursor.into_inner();
1043 self.message_read_buffer.clear();
1044
1045 #[cfg(test)]
1046 println!("<<< read message: {}", bson);
1047
1048 Ok(Some(
1049 Message::try_from(bson).map_err(Error::MessageConversionFailed)?,
1050 ))
1051 } else {
1052 self.remaining_byte_count = Some(remaining - count);
1054 Ok(None)
1055 }
1056 }
1057 }
1058 } else {
1059 Ok(None)
1060 }
1061 }
1062
1063 fn read_sections(&mut self) -> Result<(), Error> {
1065 loop {
1066 match self.read_message() {
1067 Ok(Some(mut message)) => {
1068 self.pending_sections.extend(message.sections.drain(..));
1069 }
1070 Ok(None) => return Ok(()),
1071 Err(err) => {
1072 match err {
1073 Error::MessageReadTimedOut(_) | Error::ReaderReadFailed(_) => {
1075 if self.pending_sections.is_empty() && self.inbound_responses.is_empty()
1077 {
1078 return Err(err);
1079 }
1080 return Ok(());
1081 }
1082 _ => return Err(err),
1084 }
1085 }
1086 }
1087 }
1088 }
1089
1090 fn process_sections(&mut self) -> Result<(), Error> {
1092 while let Some(section) = self.pending_sections.pop_front() {
1093 match section {
1094 Section::Error(error) => {
1095 if let Some(cookie) = error.cookie {
1096 self.inbound_responses.push_back(Response::Error {
1098 cookie,
1099 error_code: error.code,
1100 });
1101 } else {
1102 return Err(Error::UnknownErrorSectionReceived(error.code));
1103 }
1104 }
1105 Section::Request(request) => {
1106 self.inbound_requests.push(request);
1108 }
1109 Section::Response(response) => {
1110 match (response.cookie, response.state, response.result) {
1113 (cookie, RequestState::Complete, result) => {
1114 self.inbound_responses
1115 .push_back(Response::Success { cookie, result });
1116 }
1117 (cookie, RequestState::Pending, _) => {
1118 self.inbound_responses
1119 .push_back(Response::Pending { cookie });
1120 }
1121 }
1122 }
1123 }
1124 }
1125
1126 Ok(())
1127 }
1128
1129 fn push_outbound_section(&mut self, section: Section) -> Result<(), Error> {
1131 let max_section_size = self.max_message_size - MIN_MESSAGE_SIZE;
1132
1133 let mut counter: ByteCounter = Default::default();
1134 let section: bson::Document = section.into();
1135 section
1136 .to_writer(&mut counter)
1137 .map_err(Error::BsonWriteFailed)?;
1138 let section_size = counter.bytes();
1139
1140 if section_size <= max_section_size {
1141 self.outbound_sections.push(section);
1142 Ok(())
1143 } else {
1144 Err(Error::SectionTooLarge(section_size, max_section_size))
1145 }
1146 }
1147
1148 fn serialize_messages(&mut self) -> Result<(), Error> {
1150 if self.outbound_sections.is_empty() {
1152 return Ok(());
1153 }
1154
1155 let message = Message {
1157 honk_rpc: HONK_RPC_VERSION,
1158 sections: Default::default(),
1159 };
1160 let mut message = bson::document::Document::from(message);
1161 message.insert("sections", std::mem::take(&mut self.outbound_sections));
1162 self.serialize_messages_impl(message)
1163 }
1164
1165 fn serialize_messages_impl(
1167 &mut self,
1168 mut message: bson::document::Document,
1169 ) -> Result<(), Error> {
1170 self.message_serialization_buffer.clear();
1171 message
1172 .to_writer(&mut self.message_serialization_buffer)
1173 .map_err(Error::BsonWriteFailed)?;
1174
1175 if self.message_serialization_buffer.len() > self.max_message_size {
1176 let sections = message.get_array_mut("sections").unwrap();
1178 assert!(sections.len() > 1);
1179
1180 let right = doc! {
1181 "honk_rpc" : HONK_RPC_VERSION,
1182 "sections" : sections.split_off(sections.len() / 2),
1183 };
1184 let left = message;
1185
1186 self.serialize_messages_impl(left)?;
1187 self.serialize_messages_impl(right)?;
1188 } else {
1189 #[cfg(test)]
1190 println!(">>> write message: {:?}", message);
1191 self.message_write_buffer
1193 .append(&mut self.message_serialization_buffer);
1194 }
1195
1196 Ok(())
1197 }
1198
1199 fn write_pending_data(&mut self) -> Result<(), Error> {
1201 let bytes_written = self.write_pending_data_impl()?;
1202 self.stream.flush().map_err(Error::WriterWriteFailed)?;
1203 self.message_write_buffer.drain(0..bytes_written);
1205 self.message_write_buffer.make_contiguous();
1207
1208 Ok(())
1209 }
1210
1211 fn write_pending_data_impl(&mut self) -> Result<usize, Error> {
1212 let (mut pending_data, empty): (&[u8], &[u8]) = self.message_write_buffer.as_slices();
1214 assert!(empty.is_empty());
1215 let pending_bytes: usize = pending_data.len();
1216 let mut bytes_written: usize = 0usize;
1217
1218 while bytes_written != pending_bytes {
1219 match self.stream.write(pending_data) {
1220 Err(err) => {
1221 let kind = err.kind();
1222 if kind == ErrorKind::WouldBlock || kind == ErrorKind::TimedOut {
1223 return Ok(bytes_written);
1225 } else {
1226 return Err(Error::WriterWriteFailed(err));
1227 }
1228 }
1229 Ok(count) => {
1230 bytes_written += count;
1231 #[cfg(test)]
1232 println!(">>> sent {} of {} bytes", bytes_written, pending_bytes);
1233 pending_data = &pending_data[count..];
1234 }
1235 }
1236 }
1237
1238 Ok(bytes_written)
1239 }
1240
1241 pub fn update(&mut self, apisets: Option<&mut [&mut dyn ApiSet]>) -> Result<(), Error> {
1243 self.read_sections()?;
1245 self.process_sections()?;
1247
1248 let apisets = apisets.unwrap_or(&mut []);
1250 self.handle_requests(apisets)?;
1251
1252 self.serialize_messages()?;
1254
1255 self.write_pending_data()?;
1257
1258 Ok(())
1259 }
1260
1261 fn handle_requests(&mut self, apisets: &mut [&mut dyn ApiSet]) -> Result<(), Error> {
1263 let mut inbound_requests = std::mem::take(&mut self.inbound_requests);
1265 for mut request in inbound_requests.drain(..) {
1266 if let Ok(idx) =
1267 apisets.binary_search_by(|probe| probe.namespace().cmp(&request.namespace))
1268 {
1269 let apiset = match apisets.get_mut(idx) {
1270 Some(apiset) => apiset,
1271 None => unreachable!(),
1272 };
1273 match apiset.exec_function(
1274 &request.function,
1275 request.version,
1276 std::mem::take(&mut request.arguments),
1277 request.cookie,
1278 ) {
1279 Some(Ok(result)) => {
1281 if let Some(cookie) = request.cookie {
1282 self.push_outbound_section(Section::Response(ResponseSection {
1283 cookie,
1284 state: RequestState::Complete,
1285 result,
1286 }))?;
1287 }
1288 }
1289 Some(Err(error_code)) => {
1291 self.push_outbound_section(Section::Error(ErrorSection {
1292 cookie: request.cookie,
1293 code: error_code,
1294 message: None,
1295 data: None,
1296 }))?;
1297 }
1298 None => {
1300 if let Some(cookie) = request.cookie {
1301 self.push_outbound_section(Section::Response(ResponseSection {
1302 cookie,
1303 state: RequestState::Pending,
1304 result: None,
1305 }))?;
1306 }
1307 }
1308 }
1309 } else {
1310 self.push_outbound_section(Section::Error(ErrorSection {
1312 cookie: request.cookie,
1313 code: ErrorCode::RequestNamespaceInvalid,
1314 message: None,
1315 data: None,
1316 }))?;
1317 }
1318 }
1319
1320 for apiset in apisets.iter_mut() {
1322 apiset.update();
1324 while let Some((cookie, result)) = apiset.next_result() {
1326 match (cookie, result) {
1327 (cookie, Ok(result)) => {
1329 self.push_outbound_section(Section::Response(ResponseSection {
1330 cookie,
1331 state: RequestState::Complete,
1332 result,
1333 }))?;
1334 }
1335 (cookie, Err(error_code)) => {
1337 self.push_outbound_section(Section::Error(ErrorSection {
1338 cookie: Some(cookie),
1339 code: error_code,
1340 message: None,
1341 data: None,
1342 }))?;
1343 }
1344 }
1345 }
1346 }
1347 Ok(())
1348 }
1349
1350 pub fn client_call(
1352 &mut self,
1353 namespace: &str,
1354 function: &str,
1355 version: i32,
1356 arguments: bson::document::Document,
1357 ) -> Result<RequestCookie, Error> {
1358 let cookie = self.next_cookie;
1360 self.next_cookie += 1;
1361
1362 self.push_outbound_section(Section::Request(RequestSection {
1364 cookie: Some(cookie),
1365 namespace: namespace.to_string(),
1366 function: function.to_string(),
1367 version,
1368 arguments,
1369 }))?;
1370
1371 Ok(cookie)
1372 }
1373
1374 pub fn client_drain_responses(&mut self) -> std::collections::vec_deque::Drain<Response> {
1376 self.inbound_responses.drain(..)
1377 }
1378
1379 pub fn client_next_response(&mut self) -> Option<Response> {
1381 self.inbound_responses.pop_front()
1382 }
1383}
1384
/// Round-trips messages between two `Session`s over a loopback TCP connection.
///
/// Checks that:
/// * serializing with no pending sections sends nothing;
/// * a single `Error` section arrives with its cookie, code and message intact;
/// * a three-section message carries exactly one `Error`, one `Request` and one
///   `Response` section, each with the values that were sent.
#[test]
fn test_honk_client_read_write() -> anyhow::Result<()> {
    // The sockets are non-blocking, so reads are busy-polled. Bound the wait so a
    // lost message fails the test instead of hanging it forever.
    const READ_DEADLINE: std::time::Duration = std::time::Duration::from_secs(10);

    let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0u16));
    let listener = TcpListener::bind(socket_addr)?;
    let socket_addr = listener.local_addr()?;

    let stream1 = TcpStream::connect(socket_addr)?;
    stream1.set_nonblocking(true)?;
    let (stream2, _socket_addr) = listener.accept()?;
    stream2.set_nonblocking(true)?;

    let mut alice = Session::new(stream1);
    let mut pat = Session::new(stream2);

    println!("--- pat reads message, but none has been sent");

    assert!(pat.read_message()?.is_none());

    println!("--- alice sends no message, but no pending sections so no message sent");

    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads message, but none has been sent");

    match pat.read_message() {
        Ok(Some(msg)) => panic!(
            "message should not have been sent: {}",
            bson::document::Document::from(msg)
        ),
        Ok(None) => {}
        Err(err) => panic!("{:?}", err),
    }

    println!("--- pat sends an error message");

    const CUSTOM_ERROR: &str = "Custom Error!";

    pat.push_outbound_section(Section::Error(ErrorSection {
        cookie: Some(42069),
        code: ErrorCode::Runtime(1),
        message: Some(CUSTOM_ERROR.to_string()),
        data: None,
    }))?;

    pat.serialize_messages()?;
    pat.write_pending_data()?;

    println!("--- alice reads and verifies message");

    let deadline = std::time::Instant::now() + READ_DEADLINE;
    let mut msg = loop {
        if let Some(msg) = alice.read_message()? {
            break msg;
        }
        assert!(
            std::time::Instant::now() < deadline,
            "timed out waiting for alice to receive pat's error message"
        );
    };

    assert_eq!(msg.sections.len(), 1);
    match msg.sections.pop() {
        Some(Section::Error(section)) => match (section.cookie, section.code, section.message) {
            (Some(42069), ErrorCode::Runtime(1), Some(message)) => {
                assert_eq!(message, CUSTOM_ERROR);
            }
            (cookie, code, message) => panic!(
                "unexpected error section: cookie: {:?}, code: {:?}, message: {:?}",
                cookie, code, message
            ),
        },
        Some(_) => panic!("was expecting an Error section"),
        None => panic!("we should have a message"),
    }

    println!("--- alice sends multi-section message");

    alice.push_outbound_section(Section::Error(ErrorSection {
        cookie: Some(42069),
        code: ErrorCode::Runtime(2),
        message: Some(CUSTOM_ERROR.to_string()),
        data: None,
    }))?;
    alice.push_outbound_section(Section::Request(RequestSection {
        cookie: None,
        namespace: "std".to_string(),
        function: "print".to_string(),
        version: 0,
        arguments: doc! {"message": "hello!"},
    }))?;
    alice.push_outbound_section(Section::Response(ResponseSection {
        cookie: 123456,
        state: RequestState::Pending,
        result: None,
    }))?;

    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads and verifies multi-section message");

    let deadline = std::time::Instant::now() + READ_DEADLINE;
    let msg = loop {
        if let Some(msg) = pat.read_message()? {
            break msg;
        }
        assert!(
            std::time::Instant::now() < deadline,
            "timed out waiting for pat to receive alice's multi-section message"
        );
    };

    assert_eq!(msg.sections.len(), 3);
    // Count each section kind so that e.g. three Error sections cannot pass.
    let (mut errors, mut requests, mut responses) = (0usize, 0usize, 0usize);
    for section in msg.sections.iter() {
        match section {
            Section::Error(section) => {
                assert_eq!(section.cookie, Some(42069));
                assert_eq!(section.code, ErrorCode::Runtime(2));
                assert_eq!(section.message, Some(CUSTOM_ERROR.to_string()));
                assert_eq!(section.data, None);
                errors += 1;
            }
            Section::Request(section) => {
                assert_eq!(section.cookie, None);
                assert_eq!(section.namespace, "std");
                assert_eq!(section.function, "print");
                assert_eq!(section.version, 0i32);
                assert_eq!(section.arguments, doc! {"message": "hello!"});
                requests += 1;
            }
            Section::Response(section) => {
                assert_eq!(section.cookie, 123456);
                assert_eq!(section.state, RequestState::Pending);
                assert_eq!(section.result, None);
                responses += 1;
            }
        }
    }
    assert_eq!(
        (errors, requests, responses),
        (1, 1, 1),
        "expected exactly one Error, one Request and one Response section"
    );

    Ok(())
}
1522
/// Minimal `ApiSet` used by the session tests to observe dispatched calls.
#[cfg(test)]
struct TestApiSet {
    /// Number of times `namespace::function` version 0 has been executed.
    call_count: usize,
}
1527
#[cfg(test)]
impl ApiSet for TestApiSet {
    fn namespace(&self) -> &str {
        "namespace"
    }

    /// Counts invocations of `function` at version 0. Any other name/version is
    /// ignored. Every call, recognised or not, returns `Some(Ok(None))`.
    fn exec_function(
        &mut self,
        name: &str,
        version: i32,
        _args: bson::document::Document,
        _request_section: Option<RequestCookie>,
    ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
        if let ("function", 0) = (name, version) {
            println!("--- namespace::function_0() called");
            self.call_count += 1;
        }
        Some(Ok(None))
    }
}
1551
/// Checks `Session::set_max_wait_time`.
///
/// Inbound traffic that arrives within the configured window keeps the session
/// alive. Once the peer stays silent for longer than the window, `update()`
/// fails with `Error::MessageReadTimedOut`.
#[test]
fn test_honk_timeout() -> anyhow::Result<()> {
    use std::thread::sleep;
    use std::time::{Duration, Instant};

    // Connect alice and pat over an ephemeral loopback port.
    let bind_addr = SocketAddr::from(([127, 0, 0, 1], 0u16));
    let listener = TcpListener::bind(bind_addr)?;
    let listen_addr = listener.local_addr()?;

    let alice_stream = TcpStream::connect(listen_addr)?;
    alice_stream.set_nonblocking(true)?;
    alice_stream.set_nodelay(true)?;
    println!("--- alice peer_addr: {}", alice_stream.peer_addr()?);
    let (pat_stream, _peer_addr) = listener.accept()?;
    pat_stream.set_nonblocking(true)?;
    pat_stream.set_nodelay(true)?;

    let mut alice = Session::new(alice_stream);
    let mut alice_apiset = TestApiSet { call_count: 0usize };
    let mut pat = Session::new(pat_stream);

    let start = Instant::now();

    println!(
        "--- {:?} alice set max_wait_time to 3 seconds",
        start.elapsed()
    );
    alice.update(None)?;
    alice.set_max_wait_time(Duration::from_secs(3));
    alice.update(None)?;

    // 2s of silence stays inside alice's 3s window.
    println!("--- {:?} sleep 2 seconds", start.elapsed());
    sleep(Duration::from_secs(2));

    println!(
        "--- {:?} pat calls namespace::function_0()",
        start.elapsed()
    );
    pat.client_call("namespace", "function", 0, doc! {})?;
    // Pump both sessions until alice has dispatched the first call.
    while alice_apiset.call_count != 1 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    // Another 2s gap; still within the window measured from the last message.
    println!("--- {:?} sleep 2 seconds", start.elapsed());
    sleep(Duration::from_secs(2));
    pat.update(None)?;
    alice.update(None)?;

    println!(
        "--- {:?} pat calls namespace::function_0()",
        start.elapsed()
    );
    pat.client_call("namespace", "function", 0, doc! {})?;
    // Pump both sessions until alice has dispatched the second call.
    while alice_apiset.call_count != 2 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    // 4s of silence exceeds alice's 3s window, so her next update must fail.
    println!("--- {:?} sleep 4 seconds", start.elapsed());
    sleep(Duration::from_secs(4));

    println!("--- {:?} pat+alice update", start.elapsed());
    pat.update(None)?;
    match alice.update(None) {
        Err(Error::MessageReadTimedOut(waited)) => {
            println!("--- expected time out after {:?}", waited)
        }
        Ok(()) => panic!("should have timed out"),
        Err(err) => panic!("unexpected error: {:?}", err),
    }
    Ok(())
}