1use std::fmt::{Debug, Formatter};
2use std::marker::PhantomData;
3
4use proc_macro2::Span;
5use quote::quote;
6use stageleft::runtime_support::{FreeVariableWithContext, QuoteTokens};
7use stageleft::{QuotedWithContext, quote_type};
8
9use super::dynamic::LocationId;
10use super::{Location, MemberId};
11use crate::compile::builder::FlowState;
12use crate::location::member_id::TaglessMemberId;
13use crate::staging_util::{Invariant, get_this_crate};
14
/// A cluster location whose members are identified by [`MemberId`]s tagged with
/// `ClusterTag` (exposed as [`IsCluster::Tag`]).
///
/// `ClusterTag` is a type-level marker only; no value of it is stored.
pub struct Cluster<'a, ClusterTag> {
    /// Numeric identifier of this cluster; reported as `LocationId::Cluster(id)` and
    /// compared in `PartialEq`.
    pub(crate) id: usize,
    /// Handle to the flow state this cluster belongs to; two clusters are only equal when
    /// their handles are pointer-identical (`FlowState::ptr_eq`).
    pub(crate) flow_state: FlowState,
    /// Ties the location to `'a` and `ClusterTag` — presumably making both invariant
    /// (see `staging_util::Invariant`; confirm).
    pub(crate) _phantom: Invariant<'a, ClusterTag>,
}
20
21impl<C> Debug for Cluster<'_, C> {
22 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23 write!(f, "Cluster({})", self.id)
24 }
25}
26
27impl<C> Eq for Cluster<'_, C> {}
28impl<C> PartialEq for Cluster<'_, C> {
29 fn eq(&self, other: &Self) -> bool {
30 self.id == other.id && FlowState::ptr_eq(&self.flow_state, &other.flow_state)
31 }
32}
33
34impl<C> Clone for Cluster<'_, C> {
35 fn clone(&self) -> Self {
36 Cluster {
37 id: self.id,
38 flow_state: self.flow_state.clone(),
39 _phantom: PhantomData,
40 }
41 }
42}
43
44impl<'a, C> super::dynamic::DynLocation for Cluster<'a, C> {
45 fn id(&self) -> LocationId {
46 LocationId::Cluster(self.id)
47 }
48
49 fn flow_state(&self) -> &FlowState {
50 &self.flow_state
51 }
52
53 fn is_top_level() -> bool {
54 true
55 }
56}
57
58impl<'a, C> Location<'a> for Cluster<'a, C> {
59 type Root = Cluster<'a, C>;
60
61 fn root(&self) -> Self::Root {
62 self.clone()
63 }
64
65 fn name() -> String {
66 format!("Cluster<{}>", std::any::type_name::<C>())
67 }
68}
69
/// Free-variable handle that, when quoted, expands to the identifier
/// `__hydro_lang_cluster_ids_{id}` of type `&'a [TaglessMemberId]` — presumably the
/// member ids of cluster `id`, bound by generated code (confirm against the deploy glue).
pub struct ClusterIds<'a> {
    /// Cluster identifier embedded in the generated identifier name.
    pub id: usize,
    /// Carries the `'a` lifetime of the quoted `&'a [TaglessMemberId]` value.
    pub _phantom: PhantomData<&'a ()>,
}
74
75impl<'a> Clone for ClusterIds<'a> {
76 fn clone(&self) -> Self {
77 Self {
78 id: self.id,
79 _phantom: Default::default(),
80 }
81 }
82}
83
84impl<'a, Ctx> FreeVariableWithContext<Ctx> for ClusterIds<'a> {
85 type O = &'a [TaglessMemberId];
86
87 fn to_tokens(self, _ctx: &Ctx) -> QuoteTokens
88 where
89 Self: Sized,
90 {
91 let ident = syn::Ident::new(
92 &format!("__hydro_lang_cluster_ids_{}", self.id),
93 Span::call_site(),
94 );
95
96 QuoteTokens {
97 prelude: None,
98 expr: Some(quote! { #ident }),
99 }
100 }
101}
102
/// Marks `ClusterIds` as a quotable expression of type `&'a [TaglessMemberId]` in any
/// context `Ctx`; the tokens presumably come from its `FreeVariableWithContext` impl.
impl<'a, Ctx> QuotedWithContext<'a, &'a [TaglessMemberId], Ctx> for ClusterIds<'a> {}
104
/// Type-level accessor for the tag type of a cluster location.
///
/// Used as a bound on `Location::Root` so that [`ClusterSelfId`] can produce a
/// `MemberId` typed with the cluster's tag.
pub trait IsCluster {
    /// The tag type of the cluster.
    type Tag;
}

/// `Cluster<'_, C>` exposes `C` as its tag.
impl<C> IsCluster for Cluster<'_, C> {
    type Tag = C;
}
112
/// Quotable value that, inside code placed on a cluster, evaluates to the
/// [`MemberId`] of the member executing it.
///
/// Only usable where the location's root implements [`IsCluster`].
pub static CLUSTER_SELF_ID: ClusterSelfId = ClusterSelfId { _private: &() };

/// Type of [`CLUSTER_SELF_ID`].
#[derive(Clone, Copy)]
pub struct ClusterSelfId<'a> {
    /// Private field: prevents construction outside this module and carries `'a`.
    _private: &'a (),
}
121
122impl<'a, L> FreeVariableWithContext<L> for ClusterSelfId<'a>
123where
124 L: Location<'a>,
125 <L as Location<'a>>::Root: IsCluster,
126{
127 type O = MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>;
128
129 fn to_tokens(self, ctx: &L) -> QuoteTokens
130 where
131 Self: Sized,
132 {
133 let cluster_id = if let LocationId::Cluster(id) = ctx.root().id() {
134 id
135 } else {
136 unreachable!()
137 };
138
139 let ident = syn::Ident::new(
140 &format!("__hydro_lang_cluster_self_id_{}", cluster_id),
141 Span::call_site(),
142 );
143 let root = get_this_crate();
144 let c_type: syn::Type = quote_type::<<<L as Location<'a>>::Root as IsCluster>::Tag>();
145
146 QuoteTokens {
147 prelude: None,
148 expr: Some(
149 quote! { #root::location::MemberId::<#c_type>::from_tagless((#ident).clone()) },
150 ),
151 }
152 }
153}
154
/// Marks `ClusterSelfId` as a quotable expression of type `MemberId<Tag>` in any location
/// `L` whose root is a cluster; the tokens presumably come from its
/// `FreeVariableWithContext` impl.
impl<'a, L> QuotedWithContext<'a, MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>, L>
    for ClusterSelfId<'a>
where
    L: Location<'a>,
    <L as Location<'a>>::Root: IsCluster,
{
}
162
#[cfg(test)]
mod tests {
    #[cfg(feature = "sim")]
    use stageleft::q;

    #[cfg(feature = "sim")]
    use super::CLUSTER_SELF_ID;
    #[cfg(feature = "sim")]
    use crate::location::{Location, MemberId, MembershipEvent};
    #[cfg(feature = "sim")]
    use crate::nondet::nondet;
    #[cfg(feature = "sim")]
    use crate::prelude::FlowBuilder;

    /// Every member of two independently sized clusters (3 and 4 members) sends its own
    /// `CLUSTER_SELF_ID` to a process, which must receive exactly ids `0..3` and `0..4`.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_self_id() {
        let flow = FlowBuilder::new();
        let cluster1 = flow.cluster::<()>();
        let cluster2 = flow.cluster::<()>();

        let node = flow.process::<()>();

        let out_recv = cluster1
            .source_iter(q!(vec![CLUSTER_SELF_ID]))
            .send_bincode(&node)
            .values()
            .interleave(
                cluster2
                    .source_iter(q!(vec![CLUSTER_SELF_ID]))
                    .send_bincode(&node)
                    .values(),
            )
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster1, 3)
            .with_cluster_size(&cluster2, 4)
            .exhaustive(async || {
                out_recv
                    .assert_yields_only_unordered([0, 1, 2, 0, 1, 2, 3].map(MemberId::from_raw_id))
                    .await
            });
    }

    /// Each member of a 2-member cluster batches `[1, 2, 3]` into ticks and counts each
    /// batch; summed per member at the process, every member must account for all 3 items.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_with_tick() {
        use std::collections::HashMap;

        let flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = cluster
            .source_iter(q!(vec![1, 2, 3]))
            .batch(&cluster.tick(), nondet!())
            .count()
            .all_ticks()
            .send_bincode(&node)
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        let count = flow
            .sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                let grouped = out_recv.collect_sorted::<Vec<_>>().await.into_iter().fold(
                    HashMap::new(),
                    |mut acc: HashMap<MemberId<()>, usize>, (id, v)| {
                        *acc.entry(id).or_default() += v;
                        acc
                    },
                );

                // `assert_eq!` reports both sides on failure, unlike `assert!(a == b)`.
                assert_eq!(grouped.len(), 2);
                for v in grouped.into_values() {
                    assert_eq!(v, 3);
                }
            });

        // `exhaustive` returns a count (presumably of explored executions); pinned so that
        // unintended changes to the simulated search space are caught.
        assert_eq!(count, 106);
    }

    /// A process observing membership of a 2-member cluster sees a `Joined` event for
    /// each member.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_membership() {
        let flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = node
            .source_cluster_members(&cluster)
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                out_recv
                    .assert_yields_only_unordered(vec![
                        (MemberId::from_raw_id(0), MembershipEvent::Joined),
                        (MemberId::from_raw_id(1), MembershipEvent::Joined),
                    ])
                    .await;
            });
    }
}